Repository: black-forest-labs/flux
Branch: main
Commit: 802fb4713906
Files: 51
Total size: 319.8 KB

Directory structure:
gitextract_5wlz5y8p/

├── .github/
│   └── workflows/
│       └── ci.yaml
├── .gitignore
├── LICENSE
├── README.md
├── demo_gr.py
├── demo_st.py
├── demo_st_fill.py
├── docs/
│   ├── fill.md
│   ├── image-editing.md
│   ├── image-variation.md
│   ├── structural-conditioning.md
│   └── text-to-image.md
├── model_cards/
│   ├── FLUX.1-Krea-dev.md
│   ├── FLUX.1-dev.md
│   ├── FLUX.1-kontext-dev.md
│   └── FLUX.1-schnell.md
├── model_licenses/
│   ├── LICENSE-FLUX1-dev
│   └── LICENSE-FLUX1-schnell
├── pyproject.toml
├── setup.py
└── src/
    └── flux/
        ├── __init__.py
        ├── __main__.py
        ├── cli.py
        ├── cli_control.py
        ├── cli_fill.py
        ├── cli_kontext.py
        ├── cli_redux.py
        ├── content_filters.py
        ├── math.py
        ├── model.py
        ├── modules/
        │   ├── autoencoder.py
        │   ├── conditioner.py
        │   ├── image_embedders.py
        │   ├── layers.py
        │   └── lora.py
        ├── sampling.py
        ├── trt/
        │   ├── __init__.py
        │   ├── engine/
        │   │   ├── __init__.py
        │   │   ├── base_engine.py
        │   │   ├── clip_engine.py
        │   │   ├── t5_engine.py
        │   │   ├── transformer_engine.py
        │   │   └── vae_engine.py
        │   ├── trt_config/
        │   │   ├── __init__.py
        │   │   ├── base_trt_config.py
        │   │   ├── clip_trt_config.py
        │   │   ├── t5_trt_config.py
        │   │   ├── transformer_trt_config.py
        │   │   └── vae_trt_config.py
        │   └── trt_manager.py
        └── util.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .github/workflows/ci.yaml
================================================
name: CI
on: push
jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - uses: actions/setup-python@v2
        with:
          python-version: "3.10"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install ruff==0.6.8
      - name: Run Ruff
        run: ruff check --output-format=github .
      - name: Check imports
        run: ruff check --select I --output-format=github .
      - name: Check formatting
        run: ruff format --check .


================================================
FILE: .gitignore
================================================
# Created by https://www.toptal.com/developers/gitignore/api/linux,windows,macos,visualstudiocode,python
# Edit at https://www.toptal.com/developers/gitignore?templates=linux,windows,macos,visualstudiocode,python

### Linux ###
*~

# temporary files which can be created if a process still has a handle open of a deleted file
.fuse_hidden*

# KDE directory preferences
.directory

# Linux trash folder which might appear on any partition or disk
.Trash-*

# .nfs files are created when an open file is removed but is still being accessed
.nfs*

### macOS ###
# General
.DS_Store
.AppleDouble
.LSOverride

# Icon must end with two \r
Icon


# Thumbnails
._*

# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent

# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
#   For a library or package, you might want to ignore these files since the code is
#   intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

### VisualStudioCode ###
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
*.code-workspace

# Local History for Visual Studio Code
.history/

### VisualStudioCode Patch ###
# Ignore all local history of files
.history
.ionide

### Windows ###
# Windows thumbnail cache files
Thumbs.db
Thumbs.db:encryptable
ehthumbs.db
ehthumbs_vista.db

# Dump file
*.stackdump

# Folder config file
[Dd]esktop.ini

# Recycle Bin used on file shares
$RECYCLE.BIN/

# Windows Installer files
*.cab
*.msi
*.msix
*.msm
*.msp

# Windows shortcuts
*.lnk

# End of https://www.toptal.com/developers/gitignore/api/linux,windows,macos,visualstudiocode,python

output/


================================================
FILE: LICENSE
================================================
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.


================================================
FILE: README.md
================================================
# FLUX
by Black Forest Labs: https://bfl.ai.

Documentation for our API can be found here: [docs.bfl.ai](https://docs.bfl.ai/).

![grid](assets/grid.jpg)

This repo contains minimal inference code to run image generation & editing with our Flux open-weight models.

## Local installation

```bash
cd $HOME && git clone https://github.com/black-forest-labs/flux
cd $HOME/flux
python3.10 -m venv .venv
source .venv/bin/activate
pip install -e ".[all]"
```

### Local installation with TensorRT support

If you would like to install the repository with [TensorRT](https://github.com/NVIDIA/TensorRT) support, you currently need to use a PyTorch container image from NVIDIA instead. First install [enroot](https://github.com/NVIDIA/enroot), then follow the steps below:

```bash
cd $HOME && git clone https://github.com/black-forest-labs/flux
enroot import 'docker://$oauthtoken@nvcr.io#nvidia/pytorch:25.01-py3'
enroot create -n pti2501 nvidia+pytorch+25.01-py3.sqsh
enroot start --rw -m ${PWD}/flux:/workspace/flux -r pti2501
cd flux
pip install -e ".[tensorrt]" --extra-index-url https://pypi.nvidia.com
```

### Open-weight models

We offer an extensive suite of open-weight models. For more information about each model, follow the link in the **Usage** column.

| Name                        | Usage                                                      | HuggingFace repo                                               | License                                                               |
| --------------------------- | ---------------------------------------------------------- | -------------------------------------------------------------- | --------------------------------------------------------------------- |
| `FLUX.1 [schnell]`          | [Text to Image](docs/text-to-image.md)                     | https://huggingface.co/black-forest-labs/FLUX.1-schnell        | [apache-2.0](model_licenses/LICENSE-FLUX1-schnell)                    |
| `FLUX.1 [dev]`              | [Text to Image](docs/text-to-image.md)                     | https://huggingface.co/black-forest-labs/FLUX.1-dev            | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) |
| `FLUX.1 Fill [dev]`         | [In/Out-painting](docs/fill.md)                            | https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev       | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) |
| `FLUX.1 Canny [dev]`        | [Structural Conditioning](docs/structural-conditioning.md) | https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev      | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) |
| `FLUX.1 Depth [dev]`        | [Structural Conditioning](docs/structural-conditioning.md) | https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev      | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) |
| `FLUX.1 Canny [dev] LoRA`   | [Structural Conditioning](docs/structural-conditioning.md) | https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) |
| `FLUX.1 Depth [dev] LoRA`   | [Structural Conditioning](docs/structural-conditioning.md) | https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev-lora | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) |
| `FLUX.1 Redux [dev]`        | [Image variation](docs/image-variation.md)                 | https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev      | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) |
| `FLUX.1 Kontext [dev]`      | [Image editing](docs/image-editing.md)                     | https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev    | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) |
| `FLUX.1 Krea [dev]`         | [Text to Image](docs/text-to-image.md)                     | https://huggingface.co/black-forest-labs/FLUX.1-Krea-dev       | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) |

The weights of the autoencoder are also released under [apache-2.0](https://huggingface.co/datasets/choosealicense/licenses/blob/main/markdown/apache-2.0.md) and can be found in the HuggingFace repos above.
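The table above can be mirrored in code, for example to look up a model's HuggingFace repo or to flag deployments that need a commercial license. This is a minimal sketch: `MODELS`, `hf_url` and `needs_commercial_license` are hypothetical helpers, not part of the `flux` package, and the rule that only non-apache-2.0 models need a separate commercial license is an assumption drawn from the License column.

```python
from typing import Dict, NamedTuple


class ModelInfo(NamedTuple):
    hf_repo: str  # HuggingFace repo ID, as listed in the table
    license: str  # license name, as listed in the table


DEV_LICENSE = "FLUX.1-dev Non-Commercial License"

# Hypothetical lookup mirroring the README table; not part of the flux package.
MODELS: Dict[str, ModelInfo] = {
    "FLUX.1 [schnell]": ModelInfo("black-forest-labs/FLUX.1-schnell", "apache-2.0"),
    "FLUX.1 [dev]": ModelInfo("black-forest-labs/FLUX.1-dev", DEV_LICENSE),
    "FLUX.1 Fill [dev]": ModelInfo("black-forest-labs/FLUX.1-Fill-dev", DEV_LICENSE),
    "FLUX.1 Canny [dev]": ModelInfo("black-forest-labs/FLUX.1-Canny-dev", DEV_LICENSE),
    "FLUX.1 Depth [dev]": ModelInfo("black-forest-labs/FLUX.1-Depth-dev", DEV_LICENSE),
    "FLUX.1 Canny [dev] LoRA": ModelInfo("black-forest-labs/FLUX.1-Canny-dev-lora", DEV_LICENSE),
    "FLUX.1 Depth [dev] LoRA": ModelInfo("black-forest-labs/FLUX.1-Depth-dev-lora", DEV_LICENSE),
    "FLUX.1 Redux [dev]": ModelInfo("black-forest-labs/FLUX.1-Redux-dev", DEV_LICENSE),
    "FLUX.1 Kontext [dev]": ModelInfo("black-forest-labs/FLUX.1-Kontext-dev", DEV_LICENSE),
    "FLUX.1 Krea [dev]": ModelInfo("black-forest-labs/FLUX.1-Krea-dev", DEV_LICENSE),
}


def hf_url(name: str) -> str:
    """Return the HuggingFace URL for a model name from the table."""
    return f"https://huggingface.co/{MODELS[name].hf_repo}"


def needs_commercial_license(name: str) -> bool:
    """Assumption: only models not released under apache-2.0 need a commercial license."""
    return MODELS[name].license != "apache-2.0"
```

For example, `needs_commercial_license("FLUX.1 Kontext [dev]")` is `True`, while `FLUX.1 [schnell]` is covered by apache-2.0.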

## API usage

Our API offers access to all models, including our Pro-tier models whose weights are not open. Check out our API documentation at [docs.bfl.ai](https://docs.bfl.ai/) to learn more.

## Licensing models for commercial use

You can license our models for commercial use here: https://bfl.ai/pricing/licensing

As the fee is based on monthly usage, we provide code that automatically tracks your usage via the BFL API. To enable usage tracking, pass `--track_usage` to the CLI or tick the corresponding checkbox in our provided demos.

### Example: Using FLUX.1 Kontext with usage tracking

We provide a reference implementation for running FLUX.1 with usage tracking enabled for commercial licensing.
This can be customized as needed as long as the usage reporting is accurate.
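In `demo_gr.py`, usage is reported once per image, and only after the image has passed the NSFW check and been saved (`track_usage_via_api(self.model_name, 1)`). If you customize the reporting, a sketch that keeps that invariant could look like the following. `generate_with_tracking`, `generate` and `report` are hypothetical stand-ins: `generate` is assumed to return `None` when an output is rejected, and `report` plays the role of `flux.util.track_usage_via_api`.

```python
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


def generate_with_tracking(
    generate: Callable[[], Optional[T]],
    report: Callable[[str, int], None],
    model_name: str,
    track_usage: bool = True,
) -> Optional[T]:
    """Run one generation and report usage only if an image was actually produced.

    Hypothetical helper: `generate` returns None when the output is rejected
    (e.g. by a content filter), mirroring demo_gr.py, which reports only after
    an accepted image has been saved.
    """
    image = generate()
    if image is not None and track_usage:
        report(model_name, 1)  # one accepted image -> a count of 1
    return image
```

Reporting after the filter, rather than before sampling, means rejected or failed generations are not counted.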

For the reporting logic to work you will need to set your API key as an environment variable before running:
```bash
export BFL_API_KEY="your_api_key_here"
```
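Because the reporting needs the key at run time, a wrapper script can check for it up front instead of failing after the models have been loaded. This is a minimal sketch; `get_bfl_api_key` is a hypothetical helper, and the `flux` package itself may read the variable differently.

```python
import os
from typing import Mapping, Optional


def get_bfl_api_key(env: Optional[Mapping[str, str]] = None) -> str:
    """Return BFL_API_KEY from the environment, or raise with a hint if it is missing."""
    env = os.environ if env is None else env
    key = env.get("BFL_API_KEY", "").strip()
    if not key:
        raise RuntimeError(
            "BFL_API_KEY is not set; run `export BFL_API_KEY=...` before using --track_usage"
        )
    return key
```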

You can call `FLUX.1 Kontext [dev]` like this with tracking activated:

```bash
python -m flux kontext --track_usage --loop
```

For a single generation:

```bash
python -m flux kontext --track_usage --prompt "replace the logo with the text 'Black Forest Labs'"
```
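To launch the same commands from a Python script (for example a batch job), the argument list can be built explicitly instead of as a shell string, which avoids quoting problems with prompts that themselves contain quotes, like the one above. This is a sketch: `kontext_command` is a hypothetical helper and uses only the flags shown in this section.

```python
import sys
from typing import List, Optional


def kontext_command(prompt: Optional[str] = None, loop: bool = False) -> List[str]:
    """Build argv for `python -m flux kontext --track_usage` using the flags shown above."""
    cmd = [sys.executable, "-m", "flux", "kontext", "--track_usage"]
    if loop:
        cmd.append("--loop")  # interactive mode
    if prompt is not None:
        cmd += ["--prompt", prompt]  # passed as one argument, no shell quoting needed
    return cmd
```

`subprocess.run(kontext_command(prompt="replace the logo with the text 'Black Forest Labs'"), check=True)` would then run a single generation, assuming the package is installed and `BFL_API_KEY` is set.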

The above reporting logic works similarly for FLUX.1 [dev] and FLUX.1 Tools [dev].

**Note that this is only required when using one or more of our open-weight models commercially. More information on commercial licensing can be found at the [BFL Helpdesk](https://help.bfl.ai/collections/6939000511-licensing).**


## Citation

If you find the provided code or models useful for your research, consider citing them as:

```bib
@misc{labs2025flux1kontextflowmatching,
      title={FLUX.1 Kontext: Flow Matching for In-Context Image Generation and Editing in Latent Space},
      author={Black Forest Labs and Stephen Batifol and Andreas Blattmann and Frederic Boesel and Saksham Consul and Cyril Diagne and Tim Dockhorn and Jack English and Zion English and Patrick Esser and Sumith Kulal and Kyle Lacey and Yam Levi and Cheng Li and Dominik Lorenz and Jonas Müller and Dustin Podell and Robin Rombach and Harry Saini and Axel Sauer and Luke Smith},
      year={2025},
      eprint={2506.15742},
      archivePrefix={arXiv},
      primaryClass={cs.GR},
      url={https://arxiv.org/abs/2506.15742},
}

@misc{flux2024,
    author={Black Forest Labs},
    title={FLUX},
    year={2024},
    howpublished={\url{https://github.com/black-forest-labs/flux}},
}
```


================================================
FILE: demo_gr.py
================================================
import os
import time
import uuid

import gradio as gr
import numpy as np
import torch
from einops import rearrange
from PIL import ExifTags, Image
from transformers import pipeline

from flux.cli import SamplingOptions
from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from flux.util import (
    configs,
    embed_watermark,
    load_ae,
    load_clip,
    load_flow_model,
    load_t5,
    track_usage_via_api,
)

NSFW_THRESHOLD = 0.85


def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool):
    t5 = load_t5(device, max_length=256 if is_schnell else 512)
    clip = load_clip(device)
    model = load_flow_model(name, device="cpu" if offload else device)
    ae = load_ae(name, device="cpu" if offload else device)
    nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
    return model, ae, t5, clip, nsfw_classifier


class FluxGenerator:
    def __init__(self, model_name: str, device: str, offload: bool, track_usage: bool):
        self.device = torch.device(device)
        self.offload = offload
        self.model_name = model_name
        self.is_schnell = model_name == "flux-schnell"
        self.model, self.ae, self.t5, self.clip, self.nsfw_classifier = get_models(
            model_name,
            device=self.device,
            offload=self.offload,
            is_schnell=self.is_schnell,
        )
        self.track_usage = track_usage

    @torch.inference_mode()
    def generate_image(
        self,
        width,
        height,
        num_steps,
        guidance,
        seed,
        prompt,
        init_image=None,
        image2image_strength=0.0,
        add_sampling_metadata=True,
    ):
        seed = int(seed)
        if seed == -1:
            seed = None

        opts = SamplingOptions(
            prompt=prompt,
            width=width,
            height=height,
            num_steps=num_steps,
            guidance=guidance,
            seed=seed,
        )

        if opts.seed is None:
            opts.seed = torch.Generator(device="cpu").seed()
        print(f"Generating '{opts.prompt}' with seed {opts.seed}")
        t0 = time.perf_counter()

        if init_image is not None:
            if isinstance(init_image, np.ndarray):
                # HWC uint8 from Gradio -> 1CHW float in [-1, 1], the range the autoencoder expects
                init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1.0
                init_image = init_image.unsqueeze(0)
            init_image = init_image.to(self.device)
            init_image = torch.nn.functional.interpolate(init_image, (opts.height, opts.width))
            if self.offload:
                self.ae.encoder.to(self.device)
            init_image = self.ae.encode(init_image)
            if self.offload:
                self.ae.encoder.cpu()
                torch.cuda.empty_cache()

        # prepare input
        x = get_noise(
            1,
            opts.height,
            opts.width,
            device=self.device,
            dtype=torch.bfloat16,
            seed=opts.seed,
        )
        timesteps = get_schedule(
            opts.num_steps,
            x.shape[-1] * x.shape[-2] // 4,
            shift=(not self.is_schnell),
        )
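        if init_image is not None:
            # skip the first (1 - strength) fraction of the schedule and start from a
            # blend of noise and the encoded input image at that timestep
            t_idx = int((1 - image2image_strength) * opts.num_steps)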
            t = timesteps[t_idx]
            timesteps = timesteps[t_idx:]
            x = t * x + (1.0 - t) * init_image.to(x.dtype)

        if self.offload:
            self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
        inp = prepare(t5=self.t5, clip=self.clip, img=x, prompt=opts.prompt)

        # offload TEs to CPU, load model to gpu
        if self.offload:
            self.t5, self.clip = self.t5.cpu(), self.clip.cpu()
            torch.cuda.empty_cache()
            self.model = self.model.to(self.device)

        # denoise initial noise
        x = denoise(self.model, **inp, timesteps=timesteps, guidance=opts.guidance)

        # offload model, load autoencoder to gpu
        if self.offload:
            self.model.cpu()
            torch.cuda.empty_cache()
            self.ae.decoder.to(x.device)

        # decode latents to pixel space
        x = unpack(x.float(), opts.height, opts.width)
        with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
            x = self.ae.decode(x)

        if self.offload:
            self.ae.decoder.cpu()
            torch.cuda.empty_cache()

        t1 = time.perf_counter()

        print(f"Done in {t1 - t0:.1f}s.")
        # bring into PIL format
        x = x.clamp(-1, 1)
        x = embed_watermark(x.float())
        x = rearrange(x[0], "c h w -> h w c")

        img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
        nsfw_score = [res["score"] for res in self.nsfw_classifier(img) if res["label"] == "nsfw"][0]

        if nsfw_score < NSFW_THRESHOLD:
            filename = f"output/gradio/{uuid.uuid4()}.jpg"
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            exif_data = Image.Exif()
            if init_image is None:
                exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
            else:
                exif_data[ExifTags.Base.Software] = "AI generated;img2img;flux"
            exif_data[ExifTags.Base.Make] = "Black Forest Labs"
            exif_data[ExifTags.Base.Model] = self.model_name
            if add_sampling_metadata:
                exif_data[ExifTags.Base.ImageDescription] = prompt
            img.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0)

            if self.track_usage:
                track_usage_via_api(self.model_name, 1)

            return img, str(opts.seed), filename, None
        else:
            return None, str(opts.seed), None, "Your generated image may contain NSFW content."


def create_demo(
    model_name: str,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    offload: bool = False,
    track_usage: bool = False,
):
    generator = FluxGenerator(model_name, device, offload, track_usage)
    is_schnell = model_name == "flux-schnell"

    with gr.Blocks() as demo:
        gr.Markdown(f"# Flux Image Generation Demo - Model: {model_name}")

        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(
                    label="Prompt",
                    value='a photo of a forest with mist swirling around the tree trunks. The word "FLUX" is painted over it in big, red brush strokes with visible texture',
                )
                do_img2img = gr.Checkbox(label="Image to Image", value=False, interactive=not is_schnell)
                init_image = gr.Image(label="Input Image", visible=False)
                image2image_strength = gr.Slider(
                    0.0, 1.0, 0.8, step=0.1, label="Noising strength", visible=False
                )

                with gr.Accordion("Advanced Options", open=False):
                    width = gr.Slider(128, 8192, 1360, step=16, label="Width")
                    height = gr.Slider(128, 8192, 768, step=16, label="Height")
                    num_steps = gr.Slider(1, 50, 4 if is_schnell else 50, step=1, label="Number of steps")
                    guidance = gr.Slider(
                        1.0, 10.0, 3.5, step=0.1, label="Guidance", interactive=not is_schnell
                    )
                    seed = gr.Textbox(-1, label="Seed (-1 for random)")
                    add_sampling_metadata = gr.Checkbox(
                        label="Add sampling parameters to metadata?", value=True
                    )

                generate_btn = gr.Button("Generate")

            with gr.Column():
                output_image = gr.Image(label="Generated Image")
                seed_output = gr.Number(label="Used Seed")
                warning_text = gr.Textbox(label="Warning", visible=False)
                download_btn = gr.File(label="Download full-resolution")

        def update_img2img(do_img2img):
            return {
                init_image: gr.update(visible=do_img2img),
                image2image_strength: gr.update(visible=do_img2img),
            }

        do_img2img.change(update_img2img, do_img2img, [init_image, image2image_strength])

        generate_btn.click(
            fn=generator.generate_image,
            inputs=[
                width,
                height,
                num_steps,
                guidance,
                seed,
                prompt,
                init_image,
                image2image_strength,
                add_sampling_metadata,
            ],
            outputs=[output_image, seed_output, download_btn, warning_text],
        )

    return demo


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Flux")
    parser.add_argument(
        "--name", type=str, default="flux-schnell", choices=list(configs.keys()), help="Model name"
    )
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use"
    )
    parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
    parser.add_argument("--share", action="store_true", help="Create a public link to your demo")
    parser.add_argument("--track_usage", action="store_true", help="Track usage for licensing purposes")
    args = parser.parse_args()

    demo = create_demo(args.name, args.device, args.offload, args.track_usage)
    demo.launch(share=args.share)


================================================
FILE: demo_st.py
================================================
import os
import re
import time
from glob import iglob
from io import BytesIO

import streamlit as st
import torch
from einops import rearrange
from fire import Fire
from PIL import ExifTags, Image
from st_keyup import st_keyup
from torchvision import transforms
from transformers import pipeline

from flux.cli import SamplingOptions
from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from flux.util import (
    configs,
    embed_watermark,
    load_ae,
    load_clip,
    load_flow_model,
    load_t5,
    track_usage_via_api,
)

NSFW_THRESHOLD = 0.85


@st.cache_resource()
def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool):
    t5 = load_t5(device, max_length=256 if is_schnell else 512)
    clip = load_clip(device)
    model = load_flow_model(name, device="cpu" if offload else device)
    ae = load_ae(name, device="cpu" if offload else device)
    nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
    return model, ae, t5, clip, nsfw_classifier


def get_image() -> torch.Tensor | None:
    image = st.file_uploader("Input", type=["jpg", "JPEG", "png"])
    if image is None:
        return None
    image = Image.open(image).convert("RGB")

    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Lambda(lambda x: 2.0 * x - 1.0),
        ]
    )
    img: torch.Tensor = transform(image)
    return img[None, ...]


@torch.inference_mode()
def main(
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    offload: bool = False,
    output_dir: str = "output",
    track_usage: bool = False,
):
    torch_device = torch.device(device)
    names = list(configs.keys())
    name = st.selectbox("Which model to load?", names)
    if name is None or not st.checkbox("Load model", False):
        return

    is_schnell = name == "flux-schnell"
    model, ae, t5, clip, nsfw_classifier = get_models(
        name,
        device=torch_device,
        offload=offload,
        is_schnell=is_schnell,
    )

    do_img2img = (
        st.checkbox(
            "Image to Image",
            False,
            disabled=is_schnell,
            help="Partially noise an image and denoise again to get variations.\n\nOnly works for flux-dev",
        )
        and not is_schnell
    )
    if do_img2img:
        init_image = get_image()
        if init_image is None:
            st.warning("Please upload an image to use image-to-image mode")
        image2image_strength = st.number_input("Noising strength", min_value=0.0, max_value=1.0, value=0.8)
        if init_image is not None:
            h, w = init_image.shape[-2:]
            st.write(f"Got image of size {w}x{h} ({h * w / 1e6:.2f}MP)")
        resize_img = st.checkbox("Resize image", False) or init_image is None
    else:
        init_image = None
        resize_img = True
        image2image_strength = 0.0

    # allow for packing and conversion to latent space
    width = int(
        16 * (st.number_input("Width", min_value=128, value=1360, step=16, disabled=not resize_img) // 16)
    )
    height = int(
        16 * (st.number_input("Height", min_value=128, value=768, step=16, disabled=not resize_img) // 16)
    )
    num_steps = int(st.number_input("Number of steps", min_value=1, value=(4 if is_schnell else 50)))
    guidance = float(st.number_input("Guidance", min_value=1.0, value=3.5, disabled=is_schnell))
    seed_str = st.text_input("Seed", disabled=is_schnell)
    if seed_str.isdecimal():
        seed = int(seed_str)
    else:
        st.info("No seed set; enter a non-negative integer to use a fixed seed")
        seed = None
    save_samples = st.checkbox("Save samples?", not is_schnell)
    add_sampling_metadata = st.checkbox("Add sampling parameters to metadata?", True)

    default_prompt = (
        "a photo of a forest with mist swirling around the tree trunks. The word "
        '"FLUX" is painted over it in big, red brush strokes with visible texture'
    )
    prompt = st_keyup("Enter a prompt", value=default_prompt, debounce=300, key="interactive_text")

    output_name = os.path.join(output_dir, "img_{idx}.jpg")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        idx = 0
    else:
        fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
        if len(fns) > 0:
            idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
        else:
            idx = 0

    rng = torch.Generator(device="cpu")

    if "seed" not in st.session_state:
        st.session_state.seed = rng.seed()

    def increment_counter():
        st.session_state.seed += 1

    def decrement_counter():
        if st.session_state.seed > 0:
            st.session_state.seed -= 1

    opts = SamplingOptions(
        prompt=prompt,
        width=width,
        height=height,
        num_steps=num_steps,
        guidance=guidance,
        seed=seed,
    )

    if name == "flux-schnell":
        cols = st.columns([5, 1, 1, 5])
        with cols[1]:
            st.button("↩", on_click=increment_counter)
        with cols[2]:
            st.button("↪", on_click=decrement_counter)
    if is_schnell or st.button("Sample"):
        if is_schnell:
            opts.seed = st.session_state.seed
        elif opts.seed is None:
            opts.seed = rng.seed()
        print(f"Generating '{opts.prompt}' with seed {opts.seed}")
        t0 = time.perf_counter()

        if init_image is not None:
            if resize_img:
                init_image = torch.nn.functional.interpolate(init_image, (opts.height, opts.width))
            else:
                h, w = init_image.shape[-2:]
                init_image = init_image[..., : 16 * (h // 16), : 16 * (w // 16)]
                opts.height = init_image.shape[-2]
                opts.width = init_image.shape[-1]
            if offload:
                ae.encoder.to(torch_device)
            init_image = ae.encode(init_image.to(torch_device))
            if offload:
                ae = ae.cpu()
                torch.cuda.empty_cache()

        # prepare input
        x = get_noise(
            1,
            opts.height,
            opts.width,
            device=torch_device,
            dtype=torch.bfloat16,
            seed=opts.seed,
        )
        # divide pixel space by 16**2 to account for latent space conversion
        timesteps = get_schedule(
            opts.num_steps,
            (x.shape[-1] * x.shape[-2]) // 4,
            shift=(not is_schnell),
        )
        if init_image is not None:
            t_idx = int((1 - image2image_strength) * num_steps)
            t = timesteps[t_idx]
            timesteps = timesteps[t_idx:]
            x = t * x + (1.0 - t) * init_image.to(x.dtype)

        if offload:
            t5, clip = t5.to(torch_device), clip.to(torch_device)
        inp = prepare(t5=t5, clip=clip, img=x, prompt=opts.prompt)

        # offload TEs to CPU, load model to gpu
        if offload:
            t5, clip = t5.cpu(), clip.cpu()
            torch.cuda.empty_cache()
            model = model.to(torch_device)

        # denoise initial noise
        x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)

        # offload model, load autoencoder to gpu
        if offload:
            model.cpu()
            torch.cuda.empty_cache()
            ae.decoder.to(x.device)

        # decode latents to pixel space
        x = unpack(x.float(), opts.height, opts.width)
        with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
            x = ae.decode(x)

        if offload:
            ae.decoder.cpu()
            torch.cuda.empty_cache()

        t1 = time.perf_counter()

        fn = output_name.format(idx=idx)
        print(f"Done in {t1 - t0:.1f}s.")
        # bring into PIL format and save
        x = x.clamp(-1, 1)
        x = embed_watermark(x.float())
        x = rearrange(x[0], "c h w -> h w c")

        img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
        nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0]

        if nsfw_score < NSFW_THRESHOLD:
            buffer = BytesIO()
            exif_data = Image.Exif()
            if init_image is None:
                exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
            else:
                exif_data[ExifTags.Base.Software] = "AI generated;img2img;flux"
            exif_data[ExifTags.Base.Make] = "Black Forest Labs"
            exif_data[ExifTags.Base.Model] = name
            if add_sampling_metadata:
                exif_data[ExifTags.Base.ImageDescription] = prompt
            img.save(buffer, format="jpeg", exif=exif_data, quality=95, subsampling=0)

            img_bytes = buffer.getvalue()
            if save_samples:
                print(f"Saving {fn}")
                with open(fn, "wb") as file:
                    file.write(img_bytes)
                idx += 1
            if track_usage:
                track_usage_via_api(name, 1)

            st.session_state["samples"] = {
                "prompt": opts.prompt,
                "img": img,
                "seed": opts.seed,
                "bytes": img_bytes,
            }
            opts.seed = None
        else:
            st.warning("Your generated image may contain NSFW content.")
            st.session_state["samples"] = None

    samples = st.session_state.get("samples", None)
    if samples is not None:
        st.image(samples["img"], caption=samples["prompt"])
        st.download_button(
            "Download full-resolution",
            samples["bytes"],
            file_name="generated.jpg",
            mime="image/jpeg",
        )
        st.write(f"Seed: {samples['seed']}")


def app():
    Fire(main)


if __name__ == "__main__":
    app()


================================================
FILE: demo_st_fill.py
================================================
import os
import re
import tempfile
import time
from glob import iglob
from io import BytesIO

import numpy as np
import streamlit as st
import torch
from einops import rearrange
from fire import Fire
from PIL import ExifTags, Image
from st_keyup import st_keyup
from streamlit_drawable_canvas import st_canvas
from transformers import pipeline

from flux.sampling import denoise, get_noise, get_schedule, prepare_fill, unpack
from flux.util import (
    embed_watermark,
    load_ae,
    load_clip,
    load_flow_model,
    load_t5,
    track_usage_via_api,
)

NSFW_THRESHOLD = 0.85


def add_border_and_mask(image, zoom_all=1.0, zoom_left=0, zoom_right=0, zoom_up=0, zoom_down=0, overlap=0):
    """Adds a black border around the image with individual side control and mask overlap"""
    orig_width, orig_height = image.size

    # Calculate padding for each side (in pixels)
    left_pad = int(orig_width * zoom_left)
    right_pad = int(orig_width * zoom_right)
    top_pad = int(orig_height * zoom_up)
    bottom_pad = int(orig_height * zoom_down)

    # Calculate overlap in pixels
    overlap_left = int(orig_width * overlap)
    overlap_right = int(orig_width * overlap)
    overlap_top = int(orig_height * overlap)
    overlap_bottom = int(orig_height * overlap)

    # If using the all-sides zoom, add it to each side
    if zoom_all > 1.0:
        extra_each_side = (zoom_all - 1.0) / 2
        left_pad += int(orig_width * extra_each_side)
        right_pad += int(orig_width * extra_each_side)
        top_pad += int(orig_height * extra_each_side)
        bottom_pad += int(orig_height * extra_each_side)

    # Calculate new dimensions (ensure they're multiples of 32)
    new_width = 32 * round((orig_width + left_pad + right_pad) / 32)
    new_height = 32 * round((orig_height + top_pad + bottom_pad) / 32)

    # Create new image with black border
    bordered_image = Image.new("RGB", (new_width, new_height), (0, 0, 0))
    # Paste original image in position
    paste_x = left_pad
    paste_y = top_pad
    bordered_image.paste(image, (paste_x, paste_y))

    # Create mask (white where the border is, black where the original image was)
    mask = Image.new("L", (new_width, new_height), 255)  # White background
    # Paste black rectangle with overlap adjustment
    mask.paste(
        0,
        (
            paste_x + overlap_left,  # Left edge moves right
            paste_y + overlap_top,  # Top edge moves down
            paste_x + orig_width - overlap_right,  # Right edge moves left
            paste_y + orig_height - overlap_bottom,  # Bottom edge moves up
        ),
    )

    return bordered_image, mask


@st.cache_resource()
def get_models(name: str, device: torch.device, offload: bool):
    t5 = load_t5(device, max_length=128)
    clip = load_clip(device)
    model = load_flow_model(name, device="cpu" if offload else device)
    ae = load_ae(name, device="cpu" if offload else device)
    nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
    return model, ae, t5, clip, nsfw_classifier


def resize(img: Image.Image, min_mp: float = 0.5, max_mp: float = 2.0) -> Image.Image:
    width, height = img.size
    mp = (width * height) / 1_000_000  # Current megapixels

    if min_mp <= mp <= max_mp:
        # Even if MP is in range, ensure dimensions are multiples of 32
        new_width = int(32 * round(width / 32))
        new_height = int(32 * round(height / 32))
        if new_width != width or new_height != height:
            return img.resize((new_width, new_height), Image.Resampling.LANCZOS)
        return img

    # Calculate scaling factor
    if mp < min_mp:
        scale = (min_mp / mp) ** 0.5
    else:  # mp > max_mp
        scale = (max_mp / mp) ** 0.5

    new_width = int(32 * round(width * scale / 32))
    new_height = int(32 * round(height * scale / 32))

    return img.resize((new_width, new_height), Image.Resampling.LANCZOS)


def clear_canvas_state():
    """Clear all canvas-related state"""
    keys_to_clear = ["canvas", "last_image_dims"]
    for key in keys_to_clear:
        if key in st.session_state:
            del st.session_state[key]


def set_new_image(img: Image.Image):
    """Safely set a new image and clear relevant state"""
    st.session_state["current_image"] = img
    clear_canvas_state()
    st.rerun()


def downscale_image(img: Image.Image, scale_factor: float) -> Image.Image:
    """Downscale image by a given factor while maintaining 32-pixel multiple dimensions"""
    if scale_factor >= 1.0:
        return img

    width, height = img.size
    new_width = int(32 * round(width * scale_factor / 32))
    new_height = int(32 * round(height * scale_factor / 32))

    # Ensure minimum dimensions
    new_width = max(64, new_width)  # minimum 64 pixels
    new_height = max(64, new_height)  # minimum 64 pixels

    return img.resize((new_width, new_height), Image.Resampling.LANCZOS)


@torch.inference_mode()
def main(
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    offload: bool = False,
    output_dir: str = "output",
    track_usage: bool = False,
):
    torch_device = torch.device(device)
    st.title("Flux Fill: Inpainting & Outpainting")

    # Model selection and loading
    name = "flux-dev-fill"
    if not st.checkbox("Load model", False):
        return

    try:
        model, ae, t5, clip, nsfw_classifier = get_models(
            name,
            device=torch_device,
            offload=offload,
        )
    except Exception as e:
        st.error(f"Error loading models: {e}")
        return

    # Mode selection
    mode = st.radio("Select Mode", ["Inpainting", "Outpainting"])

    # Image handling - either from previous generation or new upload
    if "input_image" in st.session_state:
        image = st.session_state["input_image"]
        del st.session_state["input_image"]
        set_new_image(image)
        st.write("Continuing from previous result")
    else:
        uploaded_image = st.file_uploader("Upload image", type=["jpg", "jpeg", "png"])
        if uploaded_image is None:
            st.warning("Please upload an image")
            return

        if (
            "current_image_name" not in st.session_state
            or st.session_state["current_image_name"] != uploaded_image.name
        ):
            try:
                image = Image.open(uploaded_image).convert("RGB")
                st.session_state["current_image_name"] = uploaded_image.name
                set_new_image(image)
            except Exception as e:
                st.error(f"Error loading image: {e}")
                return
        else:
            image = st.session_state.get("current_image")
            if image is None:
                st.error("Error: Image state is invalid. Please reupload the image.")
                clear_canvas_state()
                return

    # Add downscale control
    with st.expander("Image Size Control"):
        current_mp = (image.size[0] * image.size[1]) / 1_000_000
        st.write(f"Current image size: {image.size[0]}x{image.size[1]} ({current_mp:.1f}MP)")

        scale_factor = st.slider(
            "Downscale Factor",
            min_value=0.1,
            max_value=1.0,
            value=1.0,
            step=0.1,
            help="1.0 = original size, 0.5 = half size, etc.",
        )

        if scale_factor < 1.0 and st.button("Apply Downscaling"):
            image = downscale_image(image, scale_factor)
            set_new_image(image)  # set_new_image() already calls st.rerun()

    # Resize image with validation
    try:
        original_mp = (image.size[0] * image.size[1]) / 1_000_000
        image = resize(image)
        width, height = image.size
        current_mp = (width * height) / 1_000_000

        if width % 32 != 0 or height % 32 != 0:
            st.error("Error: Image dimensions must be multiples of 32")
            return

        st.write(f"Image dimensions: {width}x{height} pixels")
        if original_mp != current_mp:
            st.write(
                f"Image has been resized from {original_mp:.1f}MP to {current_mp:.1f}MP to stay within bounds (0.5MP - 2MP)"
            )
    except Exception as e:
        st.error(f"Error processing image: {e}")
        return

    if mode == "Outpainting":
        # Outpainting controls
        zoom_all = st.slider("Zoom Out Amount (All Sides)", min_value=1.0, max_value=3.0, value=1.0, step=0.1)

        with st.expander("Advanced Zoom Controls"):
            st.info("These controls add additional zoom to specific sides")
            col1, col2 = st.columns(2)
            with col1:
                zoom_left = st.slider("Left", min_value=0.0, max_value=1.0, value=0.0, step=0.1)
                zoom_right = st.slider("Right", min_value=0.0, max_value=1.0, value=0.0, step=0.1)
            with col2:
                zoom_up = st.slider("Up", min_value=0.0, max_value=1.0, value=0.0, step=0.1)
                zoom_down = st.slider("Down", min_value=0.0, max_value=1.0, value=0.0, step=0.1)

        overlap = st.slider("Overlap", min_value=0.01, max_value=0.25, value=0.01, step=0.01)

        # Generate bordered image and mask
        image_for_generation, mask = add_border_and_mask(
            image,
            zoom_all=zoom_all,
            zoom_left=zoom_left,
            zoom_right=zoom_right,
            zoom_up=zoom_up,
            zoom_down=zoom_down,
            overlap=overlap,
        )
        width, height = image_for_generation.size

        # Show preview
        col1, col2 = st.columns(2)
        with col1:
            st.image(image_for_generation, caption="Image with Border")
        with col2:
            st.image(mask, caption="Mask (white areas will be generated)")

    else:  # Inpainting mode
        # Canvas setup with dimension tracking
        canvas_key = f"canvas_{width}_{height}"
        if "last_image_dims" not in st.session_state:
            st.session_state.last_image_dims = (width, height)
        elif st.session_state.last_image_dims != (width, height):
            clear_canvas_state()
            st.session_state.last_image_dims = (width, height)
            st.rerun()

        try:
            canvas_result = st_canvas(
                fill_color="rgba(255, 255, 255, 0.0)",
                stroke_width=st.slider("Brush size", 1, 500, 50),
                stroke_color="#fff",
                background_image=image,
                height=height,
                width=width,
                drawing_mode="freedraw",
                key=canvas_key,
                display_toolbar=True,
            )
        except Exception as e:
            st.error(f"Error creating canvas: {e}")
            clear_canvas_state()
            st.rerun()
            return

    # Sampling parameters
    num_steps = int(st.number_input("Number of steps", min_value=1, value=50))
    guidance = float(st.number_input("Guidance", min_value=1.0, value=30.0))
    seed_str = st.text_input("Seed")
    if seed_str.isdecimal():
        seed = int(seed_str)
    else:
        st.info("No seed set, using random seed")
        seed = None

    save_samples = st.checkbox("Save samples?", True)
    add_sampling_metadata = st.checkbox("Add sampling parameters to metadata?", True)

    # Prompt input
    prompt = st_keyup("Enter a prompt", value="", debounce=300, key="interactive_text")

    # Setup output path
    output_name = os.path.join(output_dir, "img_{idx}.jpg")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        idx = 0
    else:
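        fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
        # Continue after the highest existing index so earlier samples are never overwritten
        idx = max((int(fn.split("_")[-1].split(".")[0]) for fn in fns), default=-1) + 1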

    if st.button("Generate"):
        valid_input = False

        if mode == "Inpainting" and canvas_result.image_data is not None:
            valid_input = True
            # Create mask from canvas
            try:
                mask = Image.fromarray(canvas_result.image_data)
                mask = mask.getchannel("A")  # Get alpha channel
                mask_array = np.array(mask)
                mask_array = (mask_array > 0).astype(np.uint8) * 255
                mask = Image.fromarray(mask_array)
                image_for_generation = image
            except Exception as e:
                st.error(f"Error creating mask: {e}")
                return

        elif mode == "Outpainting":
            valid_input = True
            # image_for_generation and mask are already set above

        if not valid_input:
            st.error("Please draw a mask or configure outpainting settings")
            return

        # Create temporary files
        with (
            tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_img,
            tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_mask,
        ):
            try:
                image_for_generation.save(tmp_img.name)
                mask.save(tmp_mask.name)
            except Exception as e:
                st.error(f"Error saving temporary files: {e}")
                return

            try:
                # Generate inpainting/outpainting
                rng = torch.Generator(device="cpu")
                if seed is None:
                    seed = rng.seed()

                print(f"Generating with seed {seed}:\n{prompt}")
                t0 = time.perf_counter()

                x = get_noise(
                    1,
                    height,
                    width,
                    device=torch_device,
                    dtype=torch.bfloat16,
                    seed=seed,
                )

                if offload:
                    t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device)

                inp = prepare_fill(
                    t5,
                    clip,
                    x,
                    prompt=prompt,
                    ae=ae,
                    img_cond_path=tmp_img.name,
                    mask_path=tmp_mask.name,
                )

                timesteps = get_schedule(num_steps, inp["img"].shape[1], shift=True)

                if offload:
                    t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu()
                    torch.cuda.empty_cache()
                    model = model.to(torch_device)

                x = denoise(model, **inp, timesteps=timesteps, guidance=guidance)

                if offload:
                    model.cpu()
                    torch.cuda.empty_cache()
                    ae.decoder.to(x.device)

                x = unpack(x.float(), height, width)
                with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
                    x = ae.decode(x)

                t1 = time.perf_counter()
                print(f"Done in {t1 - t0:.1f}s")

                # Process and display result
                x = x.clamp(-1, 1)
                x = embed_watermark(x.float())
                x = rearrange(x[0], "c h w -> h w c")
                img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())

                nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0]

                if nsfw_score < NSFW_THRESHOLD:
                    buffer = BytesIO()
                    exif_data = Image.Exif()
                    exif_data[ExifTags.Base.Software] = "AI generated;inpainting;flux"
                    exif_data[ExifTags.Base.Make] = "Black Forest Labs"
                    exif_data[ExifTags.Base.Model] = name
                    if add_sampling_metadata:
                        exif_data[ExifTags.Base.ImageDescription] = prompt
                    img.save(buffer, format="jpeg", exif=exif_data, quality=95, subsampling=0)

                    img_bytes = buffer.getvalue()
                    if save_samples:
                        fn = output_name.format(idx=idx)
                        print(f"Saving {fn}")
                        with open(fn, "wb") as file:
                            file.write(img_bytes)

                    if track_usage:
                        track_usage_via_api(name, 1)

                    st.session_state["samples"] = {
                        "prompt": prompt,
                        "img": img,
                        "seed": seed,
                        "bytes": img_bytes,
                    }
                else:
                    st.warning("Your generated image may contain NSFW content.")
                    st.session_state["samples"] = None

            except Exception as e:
                st.error(f"Error during generation: {e}")
                return
            finally:
                # Clean up temporary files
                try:
                    os.unlink(tmp_img.name)
                    os.unlink(tmp_mask.name)
                except Exception as e:
                    print(f"Error cleaning up temporary files: {e}")

    # Display results
    samples = st.session_state.get("samples", None)
    if samples is not None:
        st.image(samples["img"], caption=samples["prompt"])
        col1, col2 = st.columns(2)
        with col1:
            st.download_button(
                "Download full-resolution",
                samples["bytes"],
                file_name="generated.jpg",
                mime="image/jpeg",
            )
        with col2:
            if st.button("Continue from this image"):
                # Store the generated image
                new_image = samples["img"]
                # Clear ALL canvas state
                clear_canvas_state()
                if "samples" in st.session_state:
                    del st.session_state["samples"]
                # Set as current image
                st.session_state["current_image"] = new_image
                st.rerun()

        st.write(f"Seed: {samples['seed']}")


def app():
    Fire(main)


if __name__ == "__main__":
    st.set_page_config(layout="wide")
    app()


================================================
FILE: docs/fill.md
================================================
## Open-weight models

FLUX.1 Fill provides inpainting and outpainting: given an image and a mask, it regenerates the masked regions so that they blend naturally with the rest of the image.

| Name                | HuggingFace repo                                         | License                                                               | sha256sum                                                        |
| ------------------- | -------------------------------------------------------- | --------------------------------------------------------------------- | ---------------------------------------------------------------- |
| `FLUX.1 Fill [dev]` | https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev | [FLUX.1-dev Non-Commercial License](../model_licenses/LICENSE-FLUX1-dev) | 03e289f530df51d014f48e675a9ffa2141bc003259bf5f25d75b957e920a41ca |

## Examples

![inpainting](../assets/docs/inpainting.png)
![outpainting](../assets/docs/outpainting.png)
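For outpainting, the Streamlit demo (`add_border_and_mask` in `demo_st_fill.py`) pads the input with a black border and builds a mask that is white over the new border and black over the original pixels. The sketch below reproduces that idea in simplified form (the same padding on every side) so that the resulting pair can be saved and passed to the CLI. `make_outpaint_inputs` is a hypothetical helper, not part of the `flux` package, and rounding both sides up to multiples of 32 is an assumption borrowed from the demo.

```python
import math

from PIL import Image


def make_outpaint_inputs(image: Image.Image, pad: int, overlap: int = 0):
    """Hypothetical helper: pad `image` with a black border and build the fill mask.

    The canvas grows by `pad` pixels on every side and is then rounded up so both
    sides are multiples of 32 (any extra space ends up on the right/bottom).
    In the mask, white (255) marks pixels to generate and black (0) marks pixels
    to keep; the black region is inset by `overlap` pixels so the seam is repainted.
    """
    w, h = image.size
    # Round the padded size up to the next multiple of 32.
    new_w = math.ceil((w + 2 * pad) / 32) * 32
    new_h = math.ceil((h + 2 * pad) / 32) * 32

    padded = Image.new("RGB", (new_w, new_h), (0, 0, 0))
    padded.paste(image.convert("RGB"), (pad, pad))

    mask = Image.new("L", (new_w, new_h), 255)  # generate everything by default
    keep_box = (pad + overlap, pad + overlap, pad + w - overlap, pad + h - overlap)
    mask.paste(0, keep_box)  # protect the original content
    return padded, mask
```

Save both results as PNG (for example `padded.save("padded.png")` and `mask.save("mask.png")`) and pass them via `--img_cond_path` and `--img_mask_path`. Rounding up instead of to the nearest multiple, as the demo does, guarantees that the original image always fits inside the enlarged canvas.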

## Open-weights usage

The weights will be downloaded automatically to `checkpoints/` from HuggingFace once you start one of the demos. Alternatively, you can download the weights manually and put them in `checkpoints/`, or point to them with the following environment variables:

```bash
export FLUX_MODEL=<your model path here>
export FLUX_AE=<your autoencoder path here>
```

For interactive sampling, run

```bash
python -m flux fill --loop
```

Or, to generate a single sample, run

```bash
python -m flux fill \
  --img_cond_path <path_to_input_image> \
  --img_mask_path <path_to_input_mask>
```

The input mask should be a black-and-white image with the same dimensions as the conditioning image, containing only black and white pixels; see [an example mask](../assets/cup_mask.png) for [this image](../assets/cup.png).
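The sketch below prepares such a mask with Pillow and NumPy. It assumes the common inpainting convention that white pixels mark the region to regenerate and black pixels mark the region to keep; compare against the example mask to confirm. The helper names `make_rect_mask`, `binarize` and `check_mask` are hypothetical and are not part of the `flux` package.

```python
from __future__ import annotations

import numpy as np
from PIL import Image


def make_rect_mask(size: tuple[int, int], box: tuple[int, int, int, int]) -> Image.Image:
    """Black mask of `size` (width, height) with a white rectangle at `box` (left, top, right, bottom)."""
    width, height = size
    mask = np.zeros((height, width), dtype=np.uint8)  # black: keep
    left, top, right, bottom = box
    mask[top:bottom, left:right] = 255  # white: region to fill (assumed convention)
    return Image.fromarray(mask)  # 2-D uint8 array -> grayscale ("L") image


def binarize(mask: Image.Image, threshold: int = 128) -> Image.Image:
    """Force a soft or anti-aliased mask to pure black (0) and white (255)."""
    arr = np.asarray(mask.convert("L"))
    return Image.fromarray(np.where(arr >= threshold, 255, 0).astype(np.uint8))


def check_mask(image: Image.Image, mask: Image.Image) -> None:
    """Raise ValueError unless the mask matches the image size and is strictly binary."""
    if mask.size != image.size:
        raise ValueError(f"mask size {mask.size} does not match image size {image.size}")
    values = set(np.unique(np.asarray(mask.convert("L"))).tolist())
    if not values <= {0, 255}:
        raise ValueError(f"mask contains gray values: {sorted(values - {0, 255})[:5]}")


# Example: mask the centre of a 1024x768 stand-in image.
img = Image.new("RGB", (1024, 768), "gray")
mask = make_rect_mask(img.size, (256, 192, 768, 576))
check_mask(img, mask)
# mask.save("my_mask.png")  # PNG keeps the mask lossless
```

Save the mask as PNG rather than JPEG. Lossy compression adds gray values along the edges, and the mask then no longer contains only black and white pixels.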

We also provide an interactive Streamlit demo, which can be run via

```bash
streamlit run demo_st_fill.py
```


================================================
FILE: docs/image-editing.md
================================================
## Open-weight models

We currently offer one open-weight image-editing model.

| Name                      | HuggingFace repo                                                | License                                                               | sha256sum                                                        |
| ------------------------- | ----------------------------------------------------------------| --------------------------------------------------------------------- | ---------------------------------------------------------------- |
| `FLUX.1 Kontext [dev]`    | https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev     | [FLUX.1-dev Non-Commercial License](../model_licenses/LICENSE-FLUX1-dev) | 843a26dc765d3105dba081c30bce7b14c65b0988f9e8d14e9fbc8856a6deebd5 |

## Examples

![FLUX.1 Kontext [dev] Grid](../assets/docs/kontext.png)

## Open-weights usage

The weights will be downloaded automatically to `checkpoints/` from HuggingFace once you start one of the demos. Alternatively, you may download the weights manually and put them in `checkpoints/`, or you can also manually link them with the following environment variables:
```bash
export FLUX_MODEL=<your model path here>
export FLUX_AE=<your autoencoder path here>
```

For interactive sampling, run

```bash
python -m flux kontext --loop
```

Or, to generate a single sample, run

```bash
python -m flux kontext \
  --img_cond_path <path_to_input_image> \
  --prompt <your_prompt> \
  --num_steps 30 --aspect_ratio "16:9" --guidance 2.5 --seed 1
```
Note that the flags `num_steps`, `aspect_ratio`, `guidance` and `seed` are
optional. For more available flags see [the code](../src/flux/cli_kontext.py).
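To script Kontext runs from Python, you can build the argument list directly. Each prompt then travels as a single argument, so prompts that contain spaces cause no shell-quoting problems. The sketch uses only the flags and environment variables shown above. The helper names `build_kontext_command` and `weights_env` are hypothetical, and the sketch does not cover any other CLI flags.

```python
from __future__ import annotations

import os
import sys


def build_kontext_command(
    img_cond_path: str,
    prompt: str,
    num_steps: int | None = None,
    aspect_ratio: str | None = None,
    guidance: float | None = None,
    seed: int | None = None,
) -> list[str]:
    """Assemble argv for `python -m flux kontext`; optional flags are left out when None."""
    cmd = [
        sys.executable, "-m", "flux", "kontext",
        "--img_cond_path", img_cond_path,
        "--prompt", prompt,  # one argv element, so no shell quoting is needed
    ]
    optional = {
        "--num_steps": num_steps,
        "--aspect_ratio": aspect_ratio,
        "--guidance": guidance,
        "--seed": seed,
    }
    for flag, value in optional.items():
        if value is not None:
            cmd += [flag, str(value)]
    return cmd


def weights_env(model_path: str | None = None, ae_path: str | None = None) -> dict[str, str]:
    """Copy the current environment, pointing FLUX_MODEL / FLUX_AE at local weights if given."""
    env = dict(os.environ)
    if model_path is not None:
        env["FLUX_MODEL"] = model_path
    if ae_path is not None:
        env["FLUX_AE"] = ae_path
    return env


# Example (not run here; needs the flux package and weights installed):
# import subprocess
# subprocess.run(
#     build_kontext_command("assets/cup.png", "make the cup red", num_steps=30, seed=1),
#     env=weights_env(model_path="/path/to/kontext-weights.safetensors"),
#     check=True,
# )
```

Each keyword argument maps to one documented flag. If you omit an option, the CLI uses its own default.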

### TRT engine inference

We provide exports in BF16, FP8, and FP4 precision. Note that you need to install the repository with TensorRT support as outlined [here](../README.md).

```bash
python -m flux kontext --loop --trt --trt_transformer_precision <precision>
```
where `<precision>` is either `bf16`, `fp8`, or `fp4_sdvd32`.


================================================
FILE: docs/image-variation.md
================================================
## Models

FLUX.1 Redux is an adapter for the FLUX.1 text-to-image base models, FLUX.1 [dev] and FLUX.1 [schnell], which can be used to generate image variations.

| Name                        | HuggingFace repo                                            | License                                                               | sha256sum                                                        |
| --------------------------- | ----------------------------------------------------------- | --------------------------------------------------------------------- | ---------------------------------------------------------------- |
| `FLUX.1 Redux [dev]`        | https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev   | [FLUX.1-dev Non-Commercial License](../model_licenses/LICENSE-FLUX1-dev) | a1b3bdcb4bdc58ce04874b9ca776d61fc3e914bb6beab41efb63e4e2694dca45 |


## Examples

![redux](../assets/docs/redux.png)

## Open-weights usage

The weights will be downloaded automatically to `checkpoints/` from HuggingFace once you start one of the demos. Alternatively, you may download the weights manually and put them in `checkpoints/`, or you can also manually link them with the following environment variables:
```bash
export FLUX_MODEL=<your model path here>
export FLUX_REDUX=<your redux path here>
export FLUX_AE=<your autoencoder path here>
```

For interactive sampling, run:

```bash
python -m flux redux --name <name> --loop
```

where `name` specifies the base model, which should be one of `flux-dev` or `flux-schnell`.


================================================
FILE: docs/structural-conditioning.md
================================================
## Models

Structural conditioning uses Canny edge or depth detection to maintain precise control during image transformations. By preserving the original image's structure through edge or depth maps, users can make text-guided edits while keeping the core composition intact. This is particularly effective for retexturing images. We release four variants: two based on edge maps (a full model and a LoRA for FLUX.1 [dev]) and two based on depth maps (a full model and a LoRA for FLUX.1 [dev]).

| Name                      | HuggingFace repo                                               | License                                                               | sha256sum                                                        |
| ------------------------- | -------------------------------------------------------------- | --------------------------------------------------------------------- | ---------------------------------------------------------------- |
| `FLUX.1 Canny [dev]`      | https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev      | [FLUX.1-dev Non-Commercial License](../model_licenses/LICENSE-FLUX1-dev) | 996876670169591cb412b937fbd46ea14cbed6933aef17c48a2dcd9685c98cdb |
| `FLUX.1 Depth [dev]`      | https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev      | [FLUX.1-dev Non-Commercial License](../model_licenses/LICENSE-FLUX1-dev) | 41360d1662f44ca45bc1b665fe6387e91802f53911001630d970a4f8be8dac21 |
| `FLUX.1 Canny [dev] LoRA` | https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora | [FLUX.1-dev Non-Commercial License](../model_licenses/LICENSE-FLUX1-dev) | 8eaa21b9c43d5e7242844deb64b8cf22ae9010f813f955ca8c05f240b8a98f7e |
| `FLUX.1 Depth [dev] LoRA` | https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev-lora | [FLUX.1-dev Non-Commercial License](../model_licenses/LICENSE-FLUX1-dev) | 1938b38ea0fdd98080fa3e48beb2bedfbc7ad102d8b65e6614de704a46d8b907 |

## Examples

![canny](../assets/docs/canny.png)
![depth](../assets/docs/depth.png)

## Open-weights usage

The weights will be downloaded automatically to `checkpoints/` from HuggingFace once you start one of the demos. Alternatively, you may download the weights manually and put them in `checkpoints/`, or you can also manually link them with the following environment variables:
```bash
export FLUX_MODEL=<your model path here>
export FLUX_AE=<your autoencoder path here>

# optional (see below)
export FLUX_LORA=<your lora path here>
```

Note that the LoRA models (`flux-dev-canny-lora` and `flux-dev-depth-lora`) require the base FLUX.1 [dev] model to be downloaded first. The system will automatically download both the base model and the LoRA adapter when using these variants.

For interactive sampling, run

```bash
python -m flux control --name <name> --loop
```

where `name` is one of `flux-dev-canny`, `flux-dev-depth`, `flux-dev-canny-lora`, or `flux-dev-depth-lora`.

### TRT engine inference

We provide exports in BF16, FP8, and FP4 precision. Note that you need to install the repository with TensorRT support as outlined [here](../README.md).

```bash
python -m flux control --name=<name> --loop --img_cond_path="assets/robot.webp" --trt --static_shape=False --trt_transformer_precision <precision>
```
where `<precision>` is either `bf16`, `fp8`, or `fp4`.

## Diffusers usage

Flux Control (including the LoRAs) is also compatible with the `diffusers` Python library. Check out the [documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) to learn more.


================================================
FILE: docs/text-to-image.md
================================================
## Open-weight models

We currently offer three open-weight text-to-image models.

| Name                      | HuggingFace repo                                         | License                                                                  | sha256sum                                                        |
| --------------------------|----------------------------------------------------------| -------------------------------------------------------------------------|----------------------------------------------------------------- |
| `FLUX.1 [schnell]`        | https://huggingface.co/black-forest-labs/FLUX.1-schnell  | [apache-2.0](../model_licenses/LICENSE-FLUX1-schnell)                    | 9403429e0052277ac2a87ad800adece5481eecefd9ed334e1f348723621d2a0a |
| `FLUX.1 [dev]`            | https://huggingface.co/black-forest-labs/FLUX.1-dev      | [FLUX.1-dev Non-Commercial License](../model_licenses/LICENSE-FLUX1-dev) | 4610115bb0c89560703c892c59ac2742fa821e60ef5871b33493ba544683abd7 |
| `FLUX.1 Krea [dev]`       | https://huggingface.co/black-forest-labs/FLUX.1-Krea-dev | [FLUX.1-dev Non-Commercial License](../model_licenses/LICENSE-FLUX1-dev) | 4610115bb0c89560703c892c59ac2742fa821e60ef5871b33493ba544683abd7 |

## Open-weights usage

The weights will be downloaded automatically to `checkpoints/` from HuggingFace once you start one of the demos. Alternatively, you may download the weights manually and put them in `checkpoints/`, or you can also manually link them with the following environment variables:
```bash
export FLUX_MODEL=<your model path here>
export FLUX_AE=<your autoencoder path here>
```

For interactive sampling, run

```bash
python -m flux t2i --name <name> --loop
```

where `name` is either `flux-dev` or `flux-schnell`. To generate a single sample, run

```bash
python -m flux t2i --name <name> \
  --height <height> --width <width> \
  --prompt "<prompt>"
```

### TRT engine inference

We provide exports in BF16, FP8, and FP4 precision. Note that you need to install the repository with TensorRT support as outlined [here](../README.md).

```bash
python -m flux t2i --name=<name> --loop --trt --trt_transformer_precision <precision>
```
where `<precision>` is either `bf16`, `fp8`, or `fp4`. For ONNX exports, `height` and `width` must each be between 768 and 1344.
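The helper below picks a resolution that fits the ONNX export range. It makes two assumptions. First, it treats the 768 and 1344 bounds as inclusive. Second, it snaps each side down to a multiple of 16, on the reasoning that FLUX's autoencoder downsamples by 8x and the latent is then grouped into 2x2 patches. If the engine still rejects a size, check the TRT configs under `src/flux/trt/`. The function names are hypothetical.

```python
from __future__ import annotations

# Bounds quoted in the docs for ONNX exports (treated as inclusive here).
TRT_MIN_SIDE = 768
TRT_MAX_SIDE = 1344
# Assumption: 8x autoencoder downsampling x 2x2 latent patches = 16 pixels per token.
MULTIPLE = 16


def is_valid_trt_resolution(height: int, width: int) -> bool:
    """Return True if both sides are multiples of 16 within [768, 1344]."""
    return all(
        TRT_MIN_SIDE <= side <= TRT_MAX_SIDE and side % MULTIPLE == 0
        for side in (height, width)
    )


def fit_trt_resolution(height: int, width: int) -> tuple[int, int]:
    """Snap each side down to a multiple of 16, then clamp it into [768, 1344]."""

    def fit(side: int) -> int:
        snapped = MULTIPLE * (side // MULTIPLE)
        # Both bounds are multiples of 16, so clamping keeps the result aligned.
        return max(TRT_MIN_SIDE, min(TRT_MAX_SIDE, snapped))

    return fit(height), fit(width)
```

Clamping each side independently can change the aspect ratio. If that matters, you can instead rescale both sides by a common factor and then snap.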

### Streamlit and Gradio

We also provide a Streamlit demo that does both text-to-image and image-to-image. The demo can be run via

```bash
streamlit run demo_st.py
```

We also offer a Gradio-based demo for an interactive experience. To run the Gradio demo:

```bash
python demo_gr.py --name flux-schnell --device cuda
```

Options:

- `--name`: Choose the model to use (options: "flux-schnell", "flux-dev")
- `--device`: Specify the device to use (default: "cuda" if available, otherwise "cpu")
- `--offload`: Offload model to CPU when not in use
- `--share`: Create a public link to your demo

To run the demo with the dev model and create a public link:

```bash
python demo_gr.py --name flux-dev --share
```

## Diffusers integration

`FLUX.1 [schnell]` and `FLUX.1 [dev]` are integrated with the [🧨 diffusers](https://github.com/huggingface/diffusers) library. To use them with diffusers, install the library:

```shell
pip install git+https://github.com/huggingface/diffusers.git
```

Then you can use `FluxPipeline` to run the model:

```python
import torch
from diffusers import FluxPipeline

model_id = "black-forest-labs/FLUX.1-schnell"  # you can also use `black-forest-labs/FLUX.1-dev`

pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # save some VRAM by offloading the model to CPU; remove this if you have enough GPU memory

prompt = "A cat holding a sign that says hello world"
seed = 42
image = pipe(
    prompt,
    output_type="pil",
    num_inference_steps=4,  # use a larger number if you are using [dev]
    generator=torch.Generator("cpu").manual_seed(seed)
).images[0]
image.save("flux-schnell.png")
```

To learn more, check out the [diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) documentation.


================================================
FILE: model_cards/FLUX.1-Krea-dev.md
================================================
![FLUX.1 Krea [dev] Grid](../assets/flux-1-krea-dev-grid.png)

`FLUX.1 Krea [dev]` is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions.
For more information, please read our [blog post](https://bfl.ai/announcements/flux-1-krea-dev).


# Key Features
1. Cutting-edge output quality, with a focus on aesthetic photography.
2. Competitive prompt following, matching the performance of closed source alternatives.
3. Trained using guidance distillation, making `FLUX.1 Krea [dev]` more efficient.
4. Open weights to drive new scientific research, and empower artists to develop innovative workflows.
5. Generated outputs can be used for personal, scientific, and commercial purposes, as described in the [flux-1-dev-non-commercial-license](https://github.com/black-forest-labs/flux/blob/main/model_licenses/LICENSE-FLUX1-dev).

# Usage
`FLUX.1 Krea [dev]` can be used as a drop-in replacement in every system that supports the original `FLUX.1 [dev]`.
A reference implementation of `FLUX.1 [dev]` is in our dedicated [github repository](https://github.com/black-forest-labs/flux).
Developers and creatives looking to build on top of `FLUX.1 [dev]` are encouraged to use this as a starting point.

`FLUX.1 Krea [dev]` is also available in both [ComfyUI](https://github.com/comfyanonymous/ComfyUI) and [Diffusers](https://github.com/huggingface/diffusers).

---

# Limitations
- This model is not intended or able to provide factual information.
- As a statistical model this checkpoint might amplify existing societal biases.
- The model may fail to generate output that matches the prompts.
- Prompt following is heavily influenced by the prompting-style.

---

# Out-of-Scope Use
The model and its derivatives may not be used

- In any way that violates any applicable national, federal, state, local or international law or regulation.
- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way; including but not limited to the solicitation, creation, acquisition, or dissemination of child exploitative content.
- To generate or disseminate verifiably false information and/or content with the purpose of harming others.
- To generate or disseminate personal identifiable information that can be used to harm an individual.
- To harass, abuse, threaten, stalk, or bully individuals or groups of individuals.
- To create non-consensual nudity or illegal pornographic content.
- For fully automated decision making that adversely impacts an individual's legal rights or otherwise creates or modifies a binding, enforceable obligation.
- Generating or facilitating large-scale disinformation campaigns.
- Please reference our [content filters](https://github.com/black-forest-labs/flux/blob/main/src/flux/content_filters.py) to avoid such generations.

---

# Risks 

Black Forest Labs (BFL) and Krea are committed to the responsible development of generative AI technology. Prior to releasing FLUX.1 Krea [dev], BFL and Krea collaboratively evaluated and mitigated a number of risks in the FLUX.1 Krea [dev] model and services, including the generation of unlawful content. We implemented a series of pre-release mitigations to help prevent misuse by third parties, with additional post-release mitigations to help address residual risks:
1. **Pre-training mitigation.** BFL filtered pre-training data for multiple categories of “not safe for work” (NSFW) and unlawful content to help prevent a user generating unlawful content in response to text prompts or uploaded images.
2. **Post-training mitigation.** BFL has partnered with the Internet Watch Foundation, an independent nonprofit organization dedicated to preventing online abuse, to filter known child sexual abuse material (CSAM) from post-training data. Subsequently, BFL and Krea undertook multiple rounds of targeted fine-tuning to provide additional mitigation against potential abuse. By inhibiting certain behaviors and concepts in the trained model, these techniques can help to prevent a user generating synthetic CSAM or nonconsensual intimate imagery (NCII) from a text prompt.
3. **Pre-release evaluation.** Throughout this process, BFL conducted internal and external third-party evaluations of model checkpoints to identify further opportunities for improvement. The third-party evaluations focused on eliciting CSAM and NCII through adversarial testing of the text-to-image model with text-only prompts. We also conducted internal evaluations of the proposed release checkpoints, comparing the model with other leading openly-available generative image models from other companies. The final FLUX.1 Krea [dev] open-weight model checkpoint demonstrated very high resilience against violative inputs and showed higher resilience than other similar open-weight models across these risk categories. Based on these findings, we approved the release of the FLUX.1 Krea [dev] model as openly-available weights under a non-commercial license to support third-party research and development.
4. **Inference filters.** The BFL GitHub repository for the open FLUX.1 Krea [dev] model includes filters for illegal or infringing content. Filters or manual review must be used with the model under the terms of the FLUX.1 [dev] Non-Commercial License. We may approach known deployers of the FLUX.1 Krea [dev] model at random to verify that filters or manual review processes are in place.
5. **Policies.** Our FLUX.1 [dev] Non-Commercial License prohibits the generation of unlawful content or the use of generated content for unlawful, defamatory, or abusive purposes. Developers and users must consent to these conditions to access the FLUX.1 Krea [dev] model.
6. **Monitoring.** BFL is monitoring for patterns of violative use after release, and may ban developers who we detect intentionally and repeatedly violate our policies. Additionally, BFL provides a dedicated email address (safety@blackforestlabs.ai) to solicit feedback from the community. BFL maintains a reporting relationship with organizations such as the Internet Watch Foundation and the National Center for Missing and Exploited Children, and BFL welcomes ongoing engagement with authorities, developers, and researchers to share intelligence about emerging risks and develop effective mitigations.

---

# License
This model falls under the [`FLUX.1 [dev]` Non-Commercial License](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md).


================================================
FILE: model_cards/FLUX.1-dev.md
================================================
![FLUX.1 [dev] Grid](../assets/dev_grid.jpg)

`FLUX.1 [dev]` is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions.
For more information, please read our [blog post](https://blackforestlabs.ai/announcing-black-forest-labs/).

# Key Features
1. Cutting-edge output quality, second only to our state-of-the-art model `FLUX.1 [pro]`.
2. Competitive prompt following, matching the performance of closed source alternatives.
3. Trained using guidance distillation, making `FLUX.1 [dev]` more efficient.
4. Open weights to drive new scientific research, and empower artists to develop innovative workflows.
5. Generated outputs can be used for personal, scientific, and commercial purposes, as described in the [flux-1-dev-non-commercial-license](https://github.com/black-forest-labs/flux/blob/main/model_licenses/LICENSE-FLUX1-dev).

# Usage
We provide a reference implementation of `FLUX.1 [dev]`, as well as sampling code, in a dedicated [github repository](https://github.com/black-forest-labs/flux).
Developers and creatives looking to build on top of `FLUX.1 [dev]` are encouraged to use this as a starting point.

## API Endpoints
The FLUX.1 models are also available via API from the following sources:
1. [bfl.ml](https://docs.bfl.ml/) (currently `FLUX.1 [pro]`)
2. [replicate.com](https://replicate.com/collections/flux)
3. [fal.ai](https://fal.ai/models/fal-ai/flux/dev)

## ComfyUI
`FLUX.1 [dev]` is also available in [ComfyUI](https://github.com/comfyanonymous/ComfyUI) for local inference with a node-based workflow.

---
# Limitations
- This model is not intended or able to provide factual information.
- As a statistical model this checkpoint might amplify existing societal biases.
- The model may fail to generate output that matches the prompts.
- Prompt following is heavily influenced by the prompting-style.

# Out-of-Scope Use
The model and its derivatives may not be used

- In any way that violates any applicable national, federal, state, local or international law or regulation.
- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way; including but not limited to the solicitation, creation, acquisition, or dissemination of child exploitative content.
- To generate or disseminate verifiably false information and/or content with the purpose of harming others.
- To generate or disseminate personal identifiable information that can be used to harm an individual.
- To harass, abuse, threaten, stalk, or bully individuals or groups of individuals.
- To create non-consensual nudity or illegal pornographic content.
- For fully automated decision making that adversely impacts an individual's legal rights or otherwise creates or modifies a binding, enforceable obligation.
- Generating or facilitating large-scale disinformation campaigns.

# License
This model falls under the [`FLUX.1 [dev]` Non-Commercial License](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md).


================================================
FILE: model_cards/FLUX.1-kontext-dev.md
================================================
![FLUX.1 Kontext [dev] Grid](../assets/docs/kontext.png)

`FLUX.1 Kontext [dev]` is a 12 billion parameter rectified flow transformer capable of editing images based on text instructions.
For more information, please read our [blog post](https://bfl.ai/announcements/flux-1-kontext-dev) and our [technical report](https://arxiv.org/abs/2506.15742).
You can find information about the `[pro]` version [here](https://bfl.ai/models/flux-kontext).

# Key Features
1. Change existing images based on an edit instruction.
2. Reference characters, styles, and objects without any fine-tuning.
3. Robust consistency allows users to refine an image through multiple successive edits with minimal visual drift.
4. Trained using guidance distillation, making `FLUX.1 Kontext [dev]` more efficient.
5. Open weights to drive new scientific research, and empower artists to develop innovative workflows.
6. Generated outputs can be used for personal, scientific, and commercial purposes, as described in the [FLUX.1 \[dev\] Non-Commercial License](https://github.com/black-forest-labs/flux/blob/main/model_licenses/LICENSE-FLUX1-dev).

# Usage
We provide a reference implementation of `FLUX.1 Kontext [dev]`, as well as sampling code, in a dedicated [github repository](https://github.com/black-forest-labs/flux).
Developers and creatives looking to build on top of `FLUX.1 Kontext [dev]` are encouraged to use this as a starting point.

`FLUX.1 Kontext [dev]` is also available in both [ComfyUI](https://github.com/comfyanonymous/ComfyUI) and [Diffusers](https://github.com/huggingface/diffusers).

## API Endpoints
The FLUX.1 Kontext models are also available via API from the following sources:
- bfl.ai: https://docs.bfl.ai/
- DataCrunch: https://datacrunch.io/flux-kontext
- fal: https://fal.ai/flux-kontext
- Replicate: https://replicate.com/blog/flux-kontext
    - https://replicate.com/black-forest-labs/flux-kontext-dev
    - https://replicate.com/black-forest-labs/flux-kontext-pro
    - https://replicate.com/black-forest-labs/flux-kontext-max
- Runware: https://runware.ai/blog/introducing-flux1-kontext-instruction-based-image-editing-with-ai?utm_source=bfl
- TogetherAI: https://www.together.ai/models/flux-1-kontext-dev

---

# Risks

Black Forest Labs is committed to the responsible development of generative AI technology. Prior to releasing FLUX.1 Kontext, we evaluated and mitigated a number of risks in our models and services, including the generation of unlawful content. We implemented a series of pre-release mitigations to help prevent misuse by third parties, with additional post-release mitigations to help address residual risks:
1. **Pre-training mitigation.** We filtered pre-training data for multiple categories of “not safe for work” (NSFW) content to help prevent a user generating unlawful content in response to text prompts or uploaded images.
2. **Post-training mitigation.** We have partnered with the Internet Watch Foundation, an independent nonprofit organization dedicated to preventing online abuse, to filter known child sexual abuse material (CSAM) from post-training data. Subsequently, we undertook multiple rounds of targeted fine-tuning to provide additional mitigation against potential abuse. By inhibiting certain behaviors and concepts in the trained model, these techniques can help to prevent a user generating synthetic CSAM or nonconsensual intimate imagery (NCII) from a text prompt, or transforming an uploaded image into synthetic CSAM or NCII.
3. **Pre-release evaluation.** Throughout this process, we conducted multiple internal and external third-party evaluations of model checkpoints to identify further opportunities for improvement. The third-party evaluations—which included 21 checkpoints of FLUX.1 Kontext [pro] and [dev]—focused on eliciting CSAM and NCII through adversarial testing with text-only prompts, as well as uploaded images with text prompts. Next, we conducted a final third-party evaluation of the proposed release checkpoints, focused on text-to-image and image-to-image CSAM and NCII generation. The final FLUX.1 Kontext [pro] (as offered through the FLUX API only) and FLUX.1 Kontext [dev] (released as an open-weight model) checkpoints demonstrated very high resilience against violative inputs, and FLUX.1 Kontext [dev] demonstrated higher resilience than other similar open-weight models across these risk categories.  Based on these findings, we approved the release of the FLUX.1 Kontext [pro] model via API, and the release of the FLUX.1 Kontext [dev] model as openly-available weights under a non-commercial license to support third-party research and development.
4. **Inference filters.** We are applying multiple filters to intercept text prompts, uploaded images, and output images on the FLUX API for FLUX.1 Kontext [pro]. Filters for CSAM and NCII are provided by Hive, a third-party provider, and cannot be adjusted or removed by developers. We provide filters for other categories of potentially harmful content, including gore, which can be adjusted by developers based on their specific risk profile. Additionally, the repository for the open FLUX.1 Kontext [dev] model includes filters for illegal or infringing content. Filters or manual review must be used with the model under the terms of the FLUX.1 [dev] Non-Commercial License. We may approach known deployers of the FLUX.1 Kontext [dev] model at random to verify that filters or manual review processes are in place.
5. **Content provenance.** The FLUX API applies cryptographically-signed metadata to output content to indicate that images were produced with our model. Our API implements the Coalition for Content Provenance and Authenticity (C2PA) standard for metadata.
6. **Policies.** Access to our API and use of our models are governed by our Developer Terms of Service, Usage Policy, and FLUX.1 [dev] Non-Commercial License, which prohibit the generation of unlawful content or the use of generated content for unlawful, defamatory, or abusive purposes. Developers and users must consent to these conditions to access the FLUX Kontext models.
7. **Monitoring.** We are monitoring for patterns of violative use after release, and may ban developers who we detect intentionally and repeatedly violate our policies via the FLUX API. Additionally, we provide a dedicated email address (safety@blackforestlabs.ai) to solicit feedback from the community. We maintain a reporting relationship with organizations such as the Internet Watch Foundation and the National Center for Missing and Exploited Children, and we welcome ongoing engagement with authorities, developers, and researchers to share intelligence about emerging risks and develop effective mitigations.


# License
This model falls under the [FLUX.1 \[dev\] Non-Commercial License](https://github.com/black-forest-labs/flux/blob/main/model_licenses/LICENSE-FLUX1-dev).


# Citation

```bib
@misc{labs2025flux1kontextflowmatching,
      title={FLUX.1 Kontext: Flow Matching for In-Context Image Generation and Editing in Latent Space},
      author={Black Forest Labs and Stephen Batifol and Andreas Blattmann and Frederic Boesel and Saksham Consul and Cyril Diagne and Tim Dockhorn and Jack English and Zion English and Patrick Esser and Sumith Kulal and Kyle Lacey and Yam Levi and Cheng Li and Dominik Lorenz and Jonas Müller and Dustin Podell and Robin Rombach and Harry Saini and Axel Sauer and Luke Smith},
      year={2025},
      eprint={2506.15742},
      archivePrefix={arXiv},
      primaryClass={cs.GR},
      url={https://arxiv.org/abs/2506.15742},
}
```


================================================
FILE: model_cards/FLUX.1-schnell.md
================================================
![FLUX.1 [schnell] Grid](../assets/schnell_grid.jpg)

`FLUX.1 [schnell]` is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions.
For more information, please read our [blog post](https://blackforestlabs.ai/announcing-black-forest-labs/).

# Key Features
1. Cutting-edge output quality and competitive prompt following, matching the performance of closed source alternatives.
2. Trained using latent adversarial diffusion distillation, `FLUX.1 [schnell]` can generate high-quality images in only 1 to 4 steps.
3. Released under the `apache-2.0` license, the model can be used for personal, scientific, and commercial purposes.

# Usage
We provide a reference implementation of `FLUX.1 [schnell]`, as well as sampling code, in a dedicated [github repository](https://github.com/black-forest-labs/flux).
Developers and creatives looking to build on top of `FLUX.1 [schnell]` are encouraged to use this as a starting point.

## API Endpoints
The FLUX.1 models are also available via API from the following sources:
1. [bfl.ml](https://docs.bfl.ml/) (currently `FLUX.1 [pro]`)
2. [replicate.com](https://replicate.com/collections/flux)
3. [fal.ai](https://fal.ai/models/fal-ai/flux/schnell)

## ComfyUI
`FLUX.1 [schnell]` is also available in [ComfyUI](https://github.com/comfyanonymous/ComfyUI) for local inference with a node-based workflow.

---
# Limitations
- This model is not intended or able to provide factual information.
- As a statistical model this checkpoint might amplify existing societal biases.
- The model may fail to generate output that matches the prompts.
- Prompt following is heavily influenced by the prompting-style.

# Out-of-Scope Use
The model and its derivatives may not be used

- In any way that violates any applicable national, federal, state, local or international law or regulation.
- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way; including but not limited to the solicitation, creation, acquisition, or dissemination of child exploitative content.
- To generate or disseminate verifiably false information and/or content with the purpose of harming others.
- To generate or disseminate personally identifiable information that can be used to harm an individual.
- To harass, abuse, threaten, stalk, or bully individuals or groups of individuals.
- To create non-consensual nudity or illegal pornographic content.
- For fully automated decision making that adversely impacts an individual's legal rights or otherwise creates or modifies a binding, enforceable obligation.
- To generate or facilitate large-scale disinformation campaigns.


================================================
FILE: model_licenses/LICENSE-FLUX1-dev
================================================
FLUX.1 [dev] Non-Commercial License v1.1.1

Black Forest Labs Inc. (“we” or “our” or “Company”) is pleased to make available the weights, parameters and inference code for the FLUX.1 [dev] Model (as defined below) freely available for your non-commercial and non-production use as set forth in this FLUX.1 [dev] Non-Commercial License (“License”).  The “FLUX.1 [dev] Model” means the FLUX.1 [dev] AI models and models denoted as FLUX.1 [dev], including but not limited to FLUX.1 [dev], FLUX.1 Fill [dev], FLUX.1 Depth [dev], FLUX.1 Canny [dev], FLUX.1 Redux [dev], FLUX.1 Canny [dev] LoRA, FLUX.1 Depth [dev] LoRA, and FLUX.1 Kontext [dev], and their elements which includes algorithms, software, checkpoints, parameters, source code (inference code, evaluation code, and if applicable, fine-tuning code) and any other materials associated with the FLUX.1 [dev] AI models made available by Company under this License, including if any, the technical documentation, manuals and instructions for the use and operation thereof (collectively, “FLUX.1 [dev] Model”). Note that we may also make available certain elements of what is included in the definition of “FLUX.1 [dev] Model” under a separate license, such as the inference code, and nothing in this License will be deemed to restrict or limit any other licenses granted by us in such elements.

By downloading, accessing, using, Distributing (as defined below), or creating a Derivative (as defined below) of the FLUX.1 [dev] Model, you agree to the terms of this License. If you do not agree to this License, then you do not have any rights to access, use, Distribute or create a Derivative of the FLUX.1 [dev] Model and you must immediately cease using the FLUX.1 [dev] Model. If you are agreeing to be bound by the terms of this License on behalf of your employer or other entity, you represent and warrant to us that you have full legal authority to bind your employer or such entity to this License. If you do not have the requisite authority, you may not accept the License or access the FLUX.1 [dev] Model on behalf of your employer or other entity.

1. Definitions.
- a. “Derivative”  means any (i) modified version of the FLUX.1 [dev] Model (including but not limited to any customized or fine-tuned version thereof), (ii) work based on the FLUX.1 [dev] Model, or (iii) any other derivative work thereof. For the avoidance of doubt, Outputs are not considered Derivatives under this License.
- b. “Distribution” or “Distribute” or “Distributing” means providing or making available, by any means, a copy of the FLUX.1 [dev] Models and/or the Derivatives as the case may be.
- c. “Non-Commercial Purpose” means any of the following uses, but only so far as you do not receive any direct or indirect payment arising from the use of the FLUX.1 [dev] Model, Derivatives, or FLUX Content Filters (as defined below): (i) personal use for research, experiment, and testing for the benefit of public knowledge, personal study, private entertainment, hobby projects, or otherwise not directly or indirectly connected to any commercial activities, business operations, or employment responsibilities; (ii) use by commercial or for-profit entities for testing, evaluation, or non-commercial research and development in a non-production environment; and (iii) use by any charitable organization for charitable purposes, or for testing or evaluation. For clarity, use (a) for revenue-generating activity, (b) in direct interactions with or that has impact on end users, or (c) to train, fine tune or distill other models for commercial use, in each case is not a Non-Commercial Purpose.
- d. “Outputs” means any content generated by the operation of the FLUX.1 [dev] Models or the Derivatives from an input (such as an image input) or prompt (i.e., text instructions) provided by users. For the avoidance of doubt, Outputs do not include any components of the FLUX.1 [dev] Models, such as any fine-tuned versions of the FLUX.1 [dev] Models, the weights, or parameters.
- e. “you” or “your” means the individual or entity entering into this License with Company.

2. License Grant.
- a. License. Subject to your compliance with this License, Company grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license to access, use, create Derivatives of, and Distribute the FLUX.1 [dev] Models and Derivatives solely for your Non-Commercial Purposes. The foregoing license is personal to you, and you may not assign or sublicense this License or any other rights or obligations under this License without Company’s prior written consent; any such assignment or sublicense will be void and will automatically and immediately terminate this License.  Any restrictions set forth herein regarding the FLUX.1 [dev] Model also apply to any Derivative you create or that are created on your behalf.
- b. Non-Commercial Use Only.  You may only access, use, Distribute, or create Derivatives of the FLUX.1 [dev] Model or Derivatives for Non-Commercial Purposes.  If you want to use a FLUX.1 [dev] Model or a Derivative for any purpose that is not expressly authorized under this License, such as for a commercial activity, you must request a license from Company, which Company may grant to you in Company’s sole discretion and which additional use may be subject to a fee, royalty or other revenue share. Please see www.bfl.ai if you would like a commercial license.
- c. Reserved Rights. The grant of rights expressly set forth in this License are the complete grant of rights to you in the FLUX.1 [dev] Model, and no other licenses are granted, whether by waiver, estoppel, implication, equity or otherwise. Company and its licensors reserve all rights not expressly granted by this License.
- d. Outputs. We claim no ownership rights in and to the Outputs. You are solely responsible for the Outputs you generate and their subsequent uses in accordance with this License. You may use Output for any purpose (including for commercial purposes), except as expressly prohibited herein.  You may not use the Output to train, fine-tune or distill a model that is competitive with the FLUX.1 [dev] Model or the FLUX.1 Kontext [dev] Model.
- e. You may access, use, Distribute, or create Output of the FLUX.1 [dev] Model or Derivatives if you: (i) (A) implement and maintain content filtering measures (“Content Filters”) for your use of the FLUX.1 [dev] Model or Derivatives to prevent the creation, display, transmission, generation, or dissemination of unlawful or infringing content, which may include Content Filters that we may make available for use with the FLUX.1 [dev] Model (“FLUX Content Filters”), or (B) ensure Output undergoes review for unlawful or infringing content before public or non-public distribution, display, transmission or dissemination; and (ii) ensure Output includes disclosure (or other indication) that the Output was generated or modified using artificial intelligence technologies to the extent required under applicable law.

3. Distribution. Subject to this License, you may Distribute copies of the FLUX.1 [dev] Model and/or Derivatives made by you, under the following conditions:
- a. you must make available a copy of this License to third-party recipients of the FLUX.1 [dev] Models and/or Derivatives you Distribute, and specify that any rights to use the FLUX.1 [dev] Models and/or Derivatives shall be directly granted by Company to said third-party recipients pursuant to this License;
- b. you must prominently display the following notice alongside the Distribution of the FLUX.1 [dev] Model or Derivative (such as via a “Notice” text file distributed as part of such FLUX.1 [dev] Model or Derivative) (the “Attribution Notice”):

    “The FLUX.1 [dev] Model is licensed by Black Forest Labs Inc. under the FLUX.1 [dev] Non-Commercial License. Copyright Black Forest Labs Inc.
    IN NO EVENT SHALL BLACK FOREST LABS INC. BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH USE OF THIS MODEL.”

- c. in the case of Distribution of Derivatives made by you: (i) you must also include in the Attribution Notice a statement that you have modified the applicable FLUX.1 [dev] Model; (ii) any terms and conditions you impose on any third-party recipients relating to Derivatives made by or for you shall neither limit such third-party recipients’ use of the FLUX.1 [dev] Model or any Derivatives made by or for Company in accordance with this License nor conflict with any of its terms and conditions and must include disclaimer of warranties and limitation of liability provisions that are at least as protective of Company as those set forth herein; and (iii) you must not misrepresent or imply, through any means, that the Derivatives made by or for you and/or any modified version of the FLUX.1 [dev] Model you Distribute under your name and responsibility is an official product of the Company or has been endorsed, approved or validated by the Company, unless you are authorized by Company to do so in writing.

4. Restrictions.  You will not, and will not permit, assist or cause any third party to
- a. use, modify, copy, reproduce, create Derivatives of, or Distribute the FLUX.1 [dev] Model (or any Derivative thereof, or any data produced by the FLUX.1 [dev] Model), in whole or in part, (i) for any commercial or production purposes, (ii) military purposes, (iii) purposes of surveillance, including any research or development relating to surveillance, (iv) biometric processing, (v) in any manner that infringes, misappropriates, or otherwise violates (or is likely to infringe, misappropriate, or otherwise violate) any third party’s legal rights, including rights of publicity or “digital replica” rights, (vi) in any unlawful, fraudulent, defamatory, or abusive activity, (vii) to generate unlawful content, including child sexual abuse material, or non-consensual intimate images;  or (viii) in any manner that violates any applicable law and violating any privacy or security laws, rules, regulations, directives, or governmental requirements (including the General Data Privacy Regulation (Regulation (EU) 2016/679), the California Consumer Privacy Act, any and all laws governing the processing of biometric information, and the EU Artificial Intelligence Act (Regulation (EU) 2024/1689), as well as all amendments and successor laws to any of the foregoing;
- b. alter or remove copyright and other proprietary notices which appear on or in any portion of the FLUX.1 [dev] Model;
- c. utilize any equipment, device, software, or other means to circumvent or remove any security or protection used by Company in connection with the FLUX.1 [dev] Model, or to circumvent or remove any usage restrictions, or to enable functionality disabled by FLUX.1 [dev] Model;
- d. offer or impose any terms on the FLUX.1 [dev] Model that alter, restrict, or are inconsistent with the terms of this License;
- e. violate any applicable U.S. and non-U.S. export control and trade sanctions laws (“Export Laws”) in connection with your use or Distribution of any FLUX.1 [dev] Model;
- f. directly or indirectly Distribute, export, or otherwise transfer FLUX.1 [dev] Model  (i) to any individual, entity, or country prohibited by Export Laws; (ii) to anyone on U.S. or non-U.S. government restricted parties lists; (iii) for any purpose prohibited by Export Laws, including nuclear, chemical or biological weapons, or missile technology applications; (iv) use or download FLUX.1 [dev] Model if you or they are (a) located in a comprehensively sanctioned jurisdiction, (b) currently listed on any U.S. or non-U.S. restricted parties list, or (c) for any purpose prohibited by Export Laws; and (v) will not disguise your location through IP proxying or other methods.

5. DISCLAIMERS.  THE FLUX.1 [dev] MODEL AND FLUX CONTENT FILTERS ARE PROVIDED “AS IS” AND “WITH ALL FAULTS” WITH NO WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. COMPANY EXPRESSLY DISCLAIMS ALL REPRESENTATIONS AND WARRANTIES, EXPRESS OR IMPLIED, WHETHER BY STATUTE, CUSTOM, USAGE OR OTHERWISE AS TO ANY MATTERS RELATED TO THE FLUX.1 [dev] MODEL AND FLUX CONTENT FILTERS, INCLUDING BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE, SATISFACTORY QUALITY, OR NON-INFRINGEMENT. COMPANY MAKES NO WARRANTIES OR REPRESENTATIONS THAT THE FLUX.1 [dev] MODEL AND FLUX CONTENT FILTERS WILL BE ERROR FREE OR FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS, OR PRODUCE ANY PARTICULAR RESULTS.

6. LIMITATION OF LIABILITY.  TO THE FULLEST EXTENT PERMITTED BY LAW, IN NO EVENT WILL COMPANY BE LIABLE TO YOU OR YOUR EMPLOYEES, AFFILIATES, USERS, OFFICERS OR DIRECTORS (A) UNDER ANY THEORY OF LIABILITY, WHETHER BASED IN CONTRACT, TORT, NEGLIGENCE, STRICT LIABILITY, WARRANTY, OR OTHERWISE UNDER THIS LICENSE, OR (B) FOR ANY INDIRECT, CONSEQUENTIAL, EXEMPLARY, INCIDENTAL, PUNITIVE OR SPECIAL DAMAGES OR LOST PROFITS, EVEN IF COMPANY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. THE FLUX.1 [dev] MODEL, ITS CONSTITUENT COMPONENTS, FLUX CONTENT FILTERS, AND ANY OUTPUT (COLLECTIVELY, “MODEL MATERIALS”) ARE NOT DESIGNED OR INTENDED FOR USE IN ANY APPLICATION OR SITUATION WHERE FAILURE OR FAULT OF THE MODEL MATERIALS COULD REASONABLY BE ANTICIPATED TO LEAD TO SERIOUS INJURY OF ANY PERSON, INCLUDING POTENTIAL DISCRIMINATION OR VIOLATION OF AN INDIVIDUAL’S PRIVACY RIGHTS, OR TO SEVERE PHYSICAL, PROPERTY, OR ENVIRONMENTAL DAMAGE (EACH, A “HIGH-RISK USE”). IF YOU ELECT TO USE ANY OF THE MODEL MATERIALS FOR A HIGH-RISK USE, YOU DO SO AT YOUR OWN RISK. YOU AGREE TO DESIGN AND IMPLEMENT APPROPRIATE DECISION-MAKING AND RISK-MITIGATION PROCEDURES AND POLICIES IN CONNECTION WITH A HIGH-RISK USE SUCH THAT EVEN IF THERE IS A FAILURE OR FAULT IN ANY OF THE MODEL MATERIALS, THE SAFETY OF PERSONS OR PROPERTY AFFECTED BY THE ACTIVITY STAYS AT A LEVEL THAT IS REASONABLE, APPROPRIATE, AND LAWFUL FOR THE FIELD OF THE HIGH-RISK USE.

7. INDEMNIFICATION. You will indemnify, defend and hold harmless Company and our subsidiaries and affiliates, and each of our respective shareholders, directors, officers, employees, agents, successors, and assigns (collectively, the “Company Parties”) from and against any losses, liabilities, damages, fines, penalties, and expenses (including reasonable attorneys’ fees) incurred by any Company Party in connection with any claim, demand, allegation, lawsuit, proceeding, or investigation (collectively, “Claims”) arising out of or related to  (a) your access to or use of the FLUX.1 [dev] Model (including in connection with any Output, results or data generated from such access or use, or from your access or use of any FLUX Content Filters), including any High-Risk Use; (b) your Content Filters, including your failure to implement any Content Filters where required by this License such as in Section 2(e); (c)  your violation of this License; or (d) your violation, misappropriation or infringement of any rights of another (including intellectual property or other proprietary rights and privacy rights). You will promptly notify the Company Parties of any such Claims, and cooperate with Company Parties in defending such Claims. You will also grant the Company Parties sole control of the defense or settlement, at Company’s sole option, of any Claims. This indemnity is in addition to, and not in lieu of, any other indemnities or remedies set forth in a written agreement between you and Company or the other Company Parties.

8. Termination; Survival.
- a. This License will automatically terminate upon any breach by you of the terms of this License.
- b. We may terminate this License, in whole or in part, at any time upon notice (including electronic) to you.
- c. If you initiate any legal action or proceedings against Company or any other entity (including a cross-claim or counterclaim in a lawsuit), alleging that the FLUX.1 [dev] Model, any Derivative, or FLUX Content Filters, or any part thereof, infringe upon intellectual property or other rights owned or licensable by you, then any licenses granted to you under this License will immediately terminate as of the date such legal action or claim is filed or initiated.
- d. Upon termination of this License, you must cease all use, access or Distribution of the FLUX.1 [dev] Model, any Derivatives, and any FLUX Content Filters. The following sections survive termination of this License: 2(c), 2(d), 4-11.

9. Third Party Materials. The FLUX.1 [dev] Model may contain third-party software or other components (including free and open source software) (all of the foregoing, “Third Party Materials”), which are subject to the license terms of the respective third-party licensors. Your dealings or correspondence with third parties and your use of or interaction with any Third Party Materials are solely between you and the third party. Company does not control or endorse, and makes no representations or warranties regarding, any Third Party Materials, and your access to and use of such Third Party Materials are at your own risk.

10. Trademarks. You have not been granted any trademark license as part of this License and may not use any name, logo or trademark associated with Company without the prior written permission of Company, except to the extent necessary to make the reference required in the Attribution Notice as specified above or as is reasonably necessary in describing the FLUX.1 [dev] Model and its creators.

11. General. This License will be governed and construed under the laws of the State of Delaware without regard to conflicts of law provisions. If any provision or part of a provision of this License is unlawful, void or unenforceable, that provision or part of the provision is deemed severed from this License, and will not affect the validity and enforceability of any remaining provisions. The failure of Company to exercise or enforce any right or provision of this License will not operate as a waiver of such right or provision. This License does not confer any third-party beneficiary rights upon any other person or entity. This License, together with the documentation, contains the entire understanding between you and Company regarding the subject matter of this License, and supersedes all other written or oral agreements and understandings between you and Company regarding such subject matter.


================================================
FILE: model_licenses/LICENSE-FLUX1-schnell
================================================


Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

    (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and
    (b) You must cause any modified files to carry prominent notices stating that You changed the files; and
    (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
    (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.

You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS


================================================
FILE: pyproject.toml
================================================
[project]
name = "flux"
authors = [
  { name = "Black Forest Labs", email = "support@blackforestlabs.ai" },
]
description = "Inference codebase for FLUX"
readme = "README.md"
requires-python = ">=3.10"
license = { file = "LICENSE.md" }
dynamic = ["version"]
dependencies = [
  "accelerate",
  "einops",
  "fire >= 0.6.0",
  "huggingface-hub",
  "safetensors",
  "sentencepiece",
  "transformers >= 4.45.2",
  "tokenizers",
  "protobuf",
  "requests",
  "invisible-watermark",
  "ruff == 0.6.8",
]

[project.optional-dependencies]
torch = [
  "torch == 2.6.0",
  "torchvision",
]
streamlit = [
  "streamlit",
  "streamlit-drawable-canvas",
  "streamlit-keyup",
]
gradio = [
  "gradio",
]
tensorrt = [
  "tensorrt-cu12 == 10.12.0.36",
  "colored",
  "opencv-python-headless==4.8.0.74",
  "onnx >=1.18.0",
  "onnxruntime ~= 1.22.0",
  "onnxruntime-gpu ~= 1.22.0",
  "onnx-graphsurgeon",
  "polygraphy >= 0.49.22",
]
all = [
  "flux[gradio]",
  "flux[streamlit]",
  "flux[torch]",
]

[project.scripts]
flux = "flux.cli:app"

[build-system]
build-backend = "setuptools.build_meta"
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8"]

[tool.ruff]
line-length = 110
target-version = "py310"
extend-exclude = ["/usr/lib/*"]

[tool.ruff.lint]
ignore = [
  "E501", # line too long - will be fixed in format
]

[tool.ruff.format]
quote-style = "double"
indent-style = "space"
line-ending = "auto"
skip-magic-trailing-comma = false
docstring-code-format = true
exclude = [
  "src/flux/_version.py", # generated by setuptools_scm
]

[tool.ruff.lint.isort]
combine-as-imports = true
force-wrap-aliases = true
known-local-folder = ["src"]
known-first-party = ["flux"]

[tool.pyright]
include = ["src"]
exclude = [
  "**/__pycache__", # cache directories
  "./typings",      # generated type stubs
]
stubPath = "./typings"

[tool.tomlsort]
in_place = true
no_sort_tables = true
spaces_before_inline_comment = 1
spaces_indent_inline_array = 2
trailing_comma_inline_array = true
sort_first = [
  "project",
  "build-system",
  "tool.setuptools",
]

# needs to be last for CI reasons
[tool.setuptools_scm]
write_to = "src/flux/_version.py"
parentdir_prefix_version = "flux-"
fallback_version = "0.0.0"
version_scheme = "post-release"


================================================
FILE: setup.py
================================================
import setuptools

setuptools.setup()


================================================
FILE: src/flux/__init__.py
================================================
try:
    from ._version import (
        version as __version__,  # type: ignore
        version_tuple,
    )
except ImportError:
    __version__ = "unknown (no version information available)"
    version_tuple = (0, 0, "unknown", "noinfo")

from pathlib import Path

PACKAGE = __package__.replace("_", "-")
PACKAGE_ROOT = Path(__file__).parent
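`src/flux/_version.py` is generated by setuptools_scm at build or install time, as configured by `write_to` in `pyproject.toml`. In a plain source checkout that was never installed, the module does not exist, so `__init__.py` falls back to a placeholder version instead of failing on import. The sketch below reproduces that pattern under one assumption: the module name is passed in as a parameter, and the name used in the example is hypothetical, so the fallback path can be exercised. `ModuleNotFoundError` is a subclass of `ImportError`, so catching `ImportError` covers a missing module.

```python
import importlib


def load_version(module_name: str) -> tuple[str, tuple]:
    """Return (version, version_tuple) from a generated version module, or a placeholder.

    Mirrors the try/except ImportError pattern in src/flux/__init__.py.
    """
    try:
        mod = importlib.import_module(module_name)
        return mod.version, mod.version_tuple
    except ImportError:
        # Same placeholder values that src/flux/__init__.py uses.
        return "unknown (no version information available)", (0, 0, "unknown", "noinfo")


def package_display_name(package: str) -> str:
    # Same transformation as PACKAGE = __package__.replace("_", "-")
    return package.replace("_", "-")


version, vtuple = load_version("hypothetical_missing_version_module")
print(version, vtuple)
print(package_display_name("my_pkg"))
```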


================================================
FILE: src/flux/__main__.py
================================================
from fire import Fire

from .cli import main as cli_main
from .cli_control import main as control_main
from .cli_fill import main as fill_main
from .cli_kontext import main as kontext_main
from .cli_redux import main as redux_main

if __name__ == "__main__":
    Fire(
        {
            "t2i": cli_main,
            "control": control_main,
            "fill": fill_main,
            "kontext": kontext_main,
            "redux": redux_main,
        }
    )
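Passing a dict to `Fire` exposes each key as a subcommand. An invocation of the form `python -m flux <command> --flag=value` is therefore routed to the matching `main` function (`t2i`, `control`, `fill`, `kontext` or `redux`), and the remaining flags become keyword arguments. The following stdlib-only sketch mimics that name-to-callable routing. It is a simplified stand-in and not Fire itself: values stay strings, while Fire also converts literals such as numbers and booleans. The command functions are hypothetical placeholders for `cli_main`, `fill_main` and the others.

```python
from typing import Callable


def dispatch(argv: list[str], commands: dict[str, Callable[..., object]]) -> object:
    """Route argv[0] to a command and pass --key=value flags as keyword arguments."""
    if not argv or argv[0] not in commands:
        raise SystemExit(f"usage: <command> [--key=value ...]; commands: {', '.join(commands)}")
    kwargs: dict[str, str] = {}
    for arg in argv[1:]:
        if not arg.startswith("--") or "=" not in arg:
            raise SystemExit(f"unsupported argument {arg!r}; expected --key=value")
        key, value = arg[2:].split("=", 1)
        kwargs[key.replace("-", "_")] = value  # normalize --some-flag to some_flag
    return commands[argv[0]](**kwargs)


# Hypothetical stand-ins for cli_main, fill_main, etc.
def fake_t2i(prompt: str = "", name: str = "flux-schnell") -> str:
    return f"t2i:{name}:{prompt}"


def fake_fill(image_path: str = "") -> str:
    return f"fill:{image_path}"


COMMANDS = {"t2i": fake_t2i, "fill": fake_fill}
print(dispatch(["t2i", "--prompt=a forest"], COMMANDS))
```

Keeping the mapping in one dict means adding a new CLI mode only requires importing its `main` and registering one more key.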


================================================
FILE: src/flux/cli.py
================================================
import os
import re
import time
from dataclasses import dataclass
from glob import iglob

import torch
from fire import Fire
from transformers import pipeline

from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from flux.util import (
    check_onnx_access_for_trt,
    configs,
    load_ae,
    load_clip,
    load_flow_model,
    load_t5,
    save_image,
)

NSFW_THRESHOLD = 0.85


@dataclass
class SamplingOptions:
    prompt: str
    width: int
    height: int
    num_steps: int
    guidance: float
    seed: int | None


def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
    user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
    usage = (
        "Usage: Either write your prompt directly, leave this field empty "
        "to repeat the prompt or write a command starting with a slash:\n"
        "- '/w <width>' will set the width of the generated image\n"
        "- '/h <height>' will set the height of the generated image\n"
        "- '/s <seed>' sets the next seed\n"
        "- '/g <guidance>' sets the guidance (flux-dev only)\n"
        "- '/n <steps>' sets the number of steps\n"
        "- '/q' to quit"
    )

    while (prompt := input(user_question)).startswith("/"):
        if prompt.startswith("/w"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, width = prompt.split()
            options.width = 16 * (int(width) // 16)
            print(
                f"Setting resolution to {options.width} x {options.height} "
                f"({options.height * options.width / 1e6:.2f}MP)"
            )
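        elif prompt.startswith("/h"):
            if prompt.strip() == "/h":
                # a bare "/h" is a help request, not a height command
                print(usage)
                continue
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue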
            _, height = prompt.split()
            options.height = 16 * (int(height) // 16)
            print(
                f"Setting resolution to {options.width} x {options.height} "
                f"({options.height * options.width / 1e6:.2f}MP)"
            )
        elif prompt.startswith("/g"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, guidance = prompt.split()
            options.guidance = float(guidance)
            print(f"Setting guidance to {options.guidance}")
        elif prompt.startswith("/s"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, seed = prompt.split()
            options.seed = int(seed)
            print(f"Setting seed to {options.seed}")
        elif prompt.startswith("/n"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, steps = prompt.split()
            options.num_steps = int(steps)
            print(f"Setting number of steps to {options.num_steps}")
        elif prompt.startswith("/q"):
            print("Quitting")
            return None
        else:
            if not prompt.startswith("/h"):
                print(f"Got invalid command '{prompt}'\n{usage}")
            print(usage)
    if prompt != "":
        options.prompt = prompt
    return options


@torch.inference_mode()
def main(
    name: str = "flux-dev-krea",
    width: int = 1360,
    height: int = 768,
    seed: int | None = None,
    prompt: str = (
        "a photo of a forest with mist swirling around the tree trunks. The word "
        '"FLUX" is painted over it in big, red brush strokes with visible texture'
    ),
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    num_steps: int | None = None,
    loop: bool = False,
    guidance: float = 2.5,
    offload: bool = False,
    output_dir: str = "output",
    add_sampling_metadata: bool = True,
    trt: bool = False,
    trt_transformer_precision: str = "bf16",
    track_usage: bool = False,
):
    """
    Sample the flux model. Either interactively (set `--loop`) or run for a
    single image.

    Args:
        name: Name of the model to load
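        height: height of the sample in pixels (rounded down to a multiple of 16)
        width: width of the sample in pixels (rounded down to a multiple of 16)
        seed: Set a seed for sampling
        output_dir: directory to save output images in; files are named
            `img_{idx}.jpg`, where `{idx}` continues after the highest existing index
        prompt: Prompt used for sampling; separate several prompts with `|`
            to sample them one after another (not allowed together with `loop`)
        device: Pytorch device
        offload: move models between CPU and device so that only the one
            currently in use occupies GPU memory
        num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)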
        loop: start an interactive session and sample multiple times
        guidance: guidance value used for guidance distillation
        add_sampling_metadata: Add the prompt to the image Exif metadata
        trt: use TensorRT backend for optimized inference
        trt_transformer_precision: specify transformer precision for inference
        track_usage: track usage of the model for licensing purposes
    """

    prompt = prompt.split("|")
    if len(prompt) == 1:
        prompt = prompt[0]
        additional_prompts = None
    else:
        additional_prompts = prompt[1:]
        prompt = prompt[0]

    assert not (
        (additional_prompts is not None) and loop
    ), "Multiple '|'-separated prompts cannot be combined with loop=True"

    nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)

    if name not in configs:
        available = ", ".join(configs.keys())
        raise ValueError(f"Got unknown model name: {name}, choose from {available}")

    torch_device = torch.device(device)
    if num_steps is None:
        num_steps = 4 if name == "flux-schnell" else 50

    # allow for packing and conversion to latent space
    height = 16 * (height // 16)
    width = 16 * (width // 16)

    output_name = os.path.join(output_dir, "img_{idx}.jpg")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        idx = 0
    else:
        fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
        if len(fns) > 0:
            idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
        else:
            idx = 0

    if not trt:
        t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512)
        clip = load_clip(torch_device)
        model = load_flow_model(name, device="cpu" if offload else torch_device)
        ae = load_ae(name, device="cpu" if offload else torch_device)
    else:
        # lazy import to make install optional
        from flux.trt.trt_manager import ModuleName, TRTManager

        # Check if we need ONNX model access (which requires authentication for FLUX models)
        onnx_dir = check_onnx_access_for_trt(name, trt_transformer_precision)

        trt_ctx_manager = TRTManager(
            trt_transformer_precision=trt_transformer_precision,
            trt_t5_precision=os.getenv("TRT_T5_PRECISION", "bf16"),
        )
        engines = trt_ctx_manager.load_engines(
            model_name=name,
            module_names={
                ModuleName.CLIP,
                ModuleName.TRANSFORMER,
                ModuleName.T5,
                ModuleName.VAE,
            },
            engine_dir=os.environ.get("TRT_ENGINE_DIR", "./engines"),
            custom_onnx_paths=onnx_dir or os.environ.get("CUSTOM_ONNX_PATHS", ""),
            trt_image_height=height,
            trt_image_width=width,
            trt_batch_size=1,
            trt_timing_cache=os.getenv("TRT_TIMING_CACHE_FILE", None),
            trt_static_batch=False,
            trt_static_shape=False,
        )

        ae = engines[ModuleName.VAE].to(device="cpu" if offload else torch_device)
        model = engines[ModuleName.TRANSFORMER].to(device="cpu" if offload else torch_device)
        clip = engines[ModuleName.CLIP].to(torch_device)
        t5 = engines[ModuleName.T5].to(device="cpu" if offload else torch_device)

    rng = torch.Generator(device="cpu")
    opts = SamplingOptions(
        prompt=prompt,
        width=width,
        height=height,
        num_steps=num_steps,
        guidance=guidance,
        seed=seed,
    )

    if loop:
        opts = parse_prompt(opts)

    while opts is not None:
        if opts.seed is None:
            opts.seed = rng.seed()
        print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
        t0 = time.perf_counter()

        # prepare input
        x = get_noise(
            1,
            opts.height,
            opts.width,
            device=torch_device,
            dtype=torch.bfloat16,
            seed=opts.seed,
        )
        opts.seed = None
        if offload:
            ae = ae.cpu()
            torch.cuda.empty_cache()
            t5, clip = t5.to(torch_device), clip.to(torch_device)
        inp = prepare(t5, clip, x, prompt=opts.prompt)
        timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))

        # offload TEs to CPU, load model to gpu
        if offload:
            t5, clip = t5.cpu(), clip.cpu()
            torch.cuda.empty_cache()
            model = model.to(torch_device)

        # denoise initial noise
        x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)

        # offload model, load autoencoder to gpu
        if offload:
            model.cpu()
            torch.cuda.empty_cache()
            ae.decoder.to(x.device)

        # decode latents to pixel space
        x = unpack(x.float(), opts.height, opts.width)
        with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
            x = ae.decode(x)

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t1 = time.perf_counter()

        fn = output_name.format(idx=idx)
        print(f"Done in {t1 - t0:.1f}s. Saving {fn}")

        # use the prompt actually sampled (it changes in loop mode and with '|'-separated prompts)
        idx = save_image(
            nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, opts.prompt, track_usage=track_usage
        )

        if loop:
            print("-" * 80)
            opts = parse_prompt(opts)
        elif additional_prompts:
            next_prompt = additional_prompts.pop(0)
            opts.prompt = next_prompt
        else:
            opts = None

    if trt:
        trt_ctx_manager.stop_runtime()


if __name__ == "__main__":
    Fire(main)


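`main` in `cli.py` rounds `height` and `width` down to multiples of 16 "to allow for packing and conversion to latent space". The following back-of-the-envelope sketch shows what that rounding buys. It assumes (this is not stated in `cli.py`) that the autoencoder downsamples by 8 and that `prepare` packs the latent into 2x2 patches, so every 16x16 pixel block becomes one transformer token. If those assumptions hold, `num_tokens` should equal `inp["img"].shape[1]`, the sequence length that `main` passes to `get_schedule`. The names `snap` and `latent_tokens` are hypothetical.

```python
def snap(x: int, multiple: int = 16) -> int:
    """Round x down to the nearest multiple, as main() does for height and width."""
    return multiple * (x // multiple)


def latent_tokens(height: int, width: int, vae_factor: int = 8, patch: int = 2) -> tuple[int, int, int]:
    """Return (latent_h, latent_w, num_tokens) under the assumed 8x VAE and 2x2 packing."""
    height, width = snap(height), snap(width)
    latent_h, latent_w = height // vae_factor, width // vae_factor
    # each patch x patch block of latent pixels is flattened into one token
    num_tokens = (latent_h // patch) * (latent_w // patch)
    return latent_h, latent_w, num_tokens
```

With the CLI defaults (`height=768`, `width=1360`), this gives a 96x170 latent and 48 * 85 = 4080 tokens. Because 16 = 8 * 2, rounding to 16 guarantees that both the VAE downsampling and the 2x2 packing divide evenly.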
================================================
FILE: src/flux/cli_control.py
================================================
import os
import re
import time
from dataclasses import dataclass
from glob import iglob

import torch
from fire import Fire
from transformers import pipeline

from flux.modules.image_embedders import CannyImageEncoder, DepthImageEncoder
from flux.sampling import denoise, get_noise, get_schedule, prepare_control, unpack
from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image


@dataclass
class SamplingOptions:
    prompt: str
    width: int
    height: int
    num_steps: int
    guidance: float
    seed: int | None
    img_cond_path: str
    lora_scale: float | None


def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
    user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
    usage = (
        "Usage: Either write your prompt directly, leave this field empty "
        "to repeat the prompt or write a command starting with a slash:\n"
        "- '/w <width>' will set the width of the generated image\n"
        "- '/h <height>' will set the height of the generated image\n"
        "- '/s <seed>' sets the next seed\n"
        "- '/g <guidance>' sets the guidance (flux-dev only)\n"
        "- '/n <steps>' sets the number of steps\n"
        "- '/q' to quit"
    )

    while (prompt := input(user_question)).startswith("/"):
        if prompt.startswith("/w"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, width = prompt.split()
            options.width = 16 * (int(width) // 16)
            print(
                f"Setting resolution to {options.width} x {options.height} "
                f"({options.height * options.width / 1e6:.2f}MP)"
            )
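        elif prompt.startswith("/h"):
            if prompt.strip() == "/h":
                # a bare "/h" is a help request, not a height command
                print(usage)
                continue
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue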
            _, height = prompt.split()
            options.height = 16 * (int(height) // 16)
            print(
                f"Setting resolution to {options.width} x {options.height} "
                f"({options.height * options.width / 1e6:.2f}MP)"
            )
        elif prompt.startswith("/g"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, guidance = prompt.split()
            options.guidance = float(guidance)
            print(f"Setting guidance to {options.guidance}")
        elif prompt.startswith("/s"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, seed = prompt.split()
            options.seed = int(seed)
            print(f"Setting seed to {options.seed}")
        elif prompt.startswith("/n"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, steps = prompt.split()
            options.num_steps = int(steps)
            print(f"Setting number of steps to {options.num_steps}")
        elif prompt.startswith("/q"):
            print("Quitting")
            return None
        else:
            if not prompt.startswith("/h"):
                print(f"Got invalid command '{prompt}'\n{usage}")
            print(usage)
    if prompt != "":
        options.prompt = prompt
    return options


def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None:
    if options is None:
        return None

    user_question = "Next conditioning image (write /h for help, /q to quit and leave empty to repeat):\n"
    usage = (
        "Usage: Either write the path to the conditioning image directly, leave this field "
        "empty to reuse the current conditioning image or write a command starting with a slash:\n"
        "- '/q' to quit"
    )

    while True:
        img_cond_path = input(user_question)

        if img_cond_path.startswith("/"):
            if img_cond_path.startswith("/q"):
                print("Quitting")
                return None
            else:
                if not img_cond_path.startswith("/h"):
                    print(f"Got invalid command '{img_cond_path}'\n{usage}")
                print(usage)
            continue

        if img_cond_path == "":
            break

        if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith(
            (".jpg", ".jpeg", ".png", ".webp")
        ):
            print(f"File '{img_cond_path}' does not exist or is not a valid image file")
            continue

        options.img_cond_path = img_cond_path
        break

    return options


def parse_lora_scale(options: SamplingOptions | None) -> tuple[SamplingOptions | None, bool]:
    changed = False

    if options is None:
        return None, changed

    user_question = "Next lora scale (write /h for help, /q to quit and leave empty to repeat):\n"
    usage = (
        "Usage: Either write the new LoRA scale directly, leave this field empty "
        "to keep the current LoRA scale or write a command starting with a slash:\n"
        "- '/q' to quit"
    )

    while (prompt := input(user_question)).startswith("/"):
        if prompt.startswith("/q"):
            print("Quitting")
            return None, changed
        else:
            if not prompt.startswith("/h"):
                print(f"Got invalid command '{prompt}'\n{usage}")
            print(usage)
    if prompt != "":
        options.lora_scale = float(prompt)
        changed = True
    return options, changed


@torch.inference_mode()
def main(
    name: str,
    width: int = 1024,
    height: int = 1024,
    seed: int | None = None,
    prompt: str = "a robot made out of gold",
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    num_steps: int = 50,
    loop: bool = False,
    guidance: float | None = None,
    offload: bool = False,
    output_dir: str = "output",
    add_sampling_metadata: bool = True,
    img_cond_path: str = "assets/robot.webp",
    lora_scale: float | None = 0.85,
    trt: bool = False,
    trt_transformer_precision: str = "bf16",
    track_usage: bool = False,
    **kwargs: dict | None,
):
    """
    Sample the flux model. Either interactively (set `--loop`) or run for a
    single image.

    Args:
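        name: name of the model to load (flux-dev-canny, flux-dev-depth,
            flux-dev-canny-lora or flux-dev-depth-lora)
        height: height of the sample in pixels (should be a multiple of 16)
        width: width of the sample in pixels (should be a multiple of 16)
        seed: Set a seed for sampling
        output_dir: directory to save output images in; files are named
            `img_{idx}.jpg`, where `{idx}` continues after the highest existing index
        prompt: Prompt used for sampling
        device: Pytorch device
        num_steps: number of sampling steps
        loop: start an interactive session and sample multiple times
        guidance: guidance value used for guidance distillation (defaults to
            30.0 for the canny models and 10.0 for the depth models)
        offload: move models between CPU and device so that only the one
            currently in use occupies GPU memory
        add_sampling_metadata: Add the prompt to the image Exif metadata
        img_cond_path: path to conditioning image (jpeg/png/webp)
        lora_scale: scale applied to the LoRA weights (LoRA models only)
        trt: use TensorRT backend for optimized inference (not supported for LoRA models)
        trt_transformer_precision: specify transformer precision for inference
        track_usage: track usage of the model for licensing purposes
        **kwargs: `static_batch` and `static_shape`, forwarded to the TensorRT
            engine loader (both default to True)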
    """
    nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)

    if "lora" in name:
        assert not trt, "TRT does not support LORA"
    assert name in [
        "flux-dev-canny",
        "flux-dev-depth",
        "flux-dev-canny-lora",
        "flux-dev-depth-lora",
    ], f"Got unknown model name: {name}"

    if guidance is None:
        if name in ["flux-dev-canny", "flux-dev-canny-lora"]:
            guidance = 30.0
        elif name in ["flux-dev-depth", "flux-dev-depth-lora"]:
            guidance = 10.0
        else:
            raise NotImplementedError()

    if name not in configs:
        available = ", ".join(configs.keys())
        raise ValueError(f"Got unknown model name: {name}, choose from {available}")

    torch_device = torch.device(device)

    output_name = os.path.join(output_dir, "img_{idx}.jpg")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        idx = 0
    else:
        fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
        if len(fns) > 0:
            idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
        else:
            idx = 0

    if name in ["flux-dev-depth", "flux-dev-depth-lora"]:
        img_embedder = DepthImageEncoder(torch_device)
    elif name in ["flux-dev-canny", "flux-dev-canny-lora"]:
        img_embedder = CannyImageEncoder(torch_device)
    else:
        raise NotImplementedError()

    if not trt:
        # init all components
        t5 = load_t5(torch_device, max_length=512)
        clip = load_clip(torch_device)
        model = load_flow_model(name, device="cpu" if offload else torch_device)
        ae = load_ae(name, device="cpu" if offload else torch_device)
    else:
        # lazy import to make install optional
        from flux.trt.trt_manager import ModuleName, TRTManager

        trt_ctx_manager = TRTManager(
            trt_transformer_precision=trt_transformer_precision,
            trt_t5_precision=os.environ.get("TRT_T5_PRECISION", "bf16"),
        )

        engines = trt_ctx_manager.load_engines(
            model_name=name,
            module_names={
                ModuleName.CLIP,
                ModuleName.TRANSFORMER,
                ModuleName.T5,
                ModuleName.VAE,
                ModuleName.VAE_ENCODER,
            },
            engine_dir=os.environ.get("TRT_ENGINE_DIR", "./engines"),
            custom_onnx_paths=os.environ.get("CUSTOM_ONNX_PATHS", ""),
            trt_image_height=height,
            trt_image_width=width,
            trt_batch_size=1,
            trt_static_batch=kwargs.get("static_batch", True),
            trt_static_shape=kwargs.get("static_shape", True),
        )

        ae = engines[ModuleName.VAE].to(device="cpu" if offload else torch_device)
        model = engines[ModuleName.TRANSFORMER].to(device="cpu" if offload else torch_device)
        clip = engines[ModuleName.CLIP].to(torch_device)
        t5 = engines[ModuleName.T5].to(device="cpu" if offload else torch_device)

    # set lora scale
    if "lora" in name and lora_scale is not None:
        for _, module in model.named_modules():
            if hasattr(module, "set_scale"):
                module.set_scale(lora_scale)

    rng = torch.Generator(device="cpu")
    opts = SamplingOptions(
        prompt=prompt,
        width=width,
        height=height,
        num_steps=num_steps,
        guidance=guidance,
        seed=seed,
        img_cond_path=img_cond_path,
        lora_scale=lora_scale,
    )

    if loop:
        opts = parse_prompt(opts)
        opts = parse_img_cond_path(opts)
        if "lora" in name:
            opts, changed = parse_lora_scale(opts)
            if changed:
                # update the lora scale:
                for _, module in model.named_modules():
                    if hasattr(module, "set_scale"):
                        module.set_scale(opts.lora_scale)

    while opts is not None:
        if opts.seed is None:
            opts.seed = rng.seed()
        print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
        t0 = time.perf_counter()

        # prepare input
        x = get_noise(
            1,
            opts.height,
            opts.width,
            device=torch_device,
            dtype=torch.bfloat16,
            seed=opts.seed,
        )
        opts.seed = None
        if offload:
            t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device)
        inp = prepare_control(
            t5,
            clip,
            x,
            prompt=opts.prompt,
            ae=ae,
            encoder=img_embedder,
            img_cond_path=opts.img_cond_path,
        )
        timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))

        # offload TEs and AE to CPU, load model to gpu
        if offload:
            t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu()
            torch.cuda.empty_cache()
            model = model.to(torch_device)

        # denoise initial noise
        x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)

        # offload model, load autoencoder to gpu
        if offload:
            model.cpu()
            torch.cuda.empty_cache()
            ae.decoder.to(x.device)

        # decode latents to pixel space
        x = unpack(x.float(), opts.height, opts.width)
        with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
            x = ae.decode(x)

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t1 = time.perf_counter()
        print(f"Done in {t1 - t0:.1f}s")

        # use the prompt actually sampled (it changes in loop mode)
        idx = save_image(
            nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, opts.prompt, track_usage=track_usage
        )

        if loop:
            print("-" * 80)
            opts = parse_prompt(opts)
            opts = parse_img_cond_path(opts)
            if "lora" in name:
                opts, changed = parse_lora_scale(opts)
                if changed:
                    # update the lora scale:
                    for _, module in model.named_modules():
                        if hasattr(module, "set_scale"):
                            module.set_scale(opts.lora_scale)
        else:
            opts = None

    if trt:
        trt_ctx_manager.stop_runtime()


if __name__ == "__main__":
    Fire(main)


================================================
FILE: src/flux/cli_fill.py
================================================
import os
import re
import time
from dataclasses import dataclass
from glob import iglob

import torch
from fire import Fire
from PIL import Image
from transformers import pipeline

from flux.sampling import denoise, get_noise, get_schedule, prepare_fill, unpack
from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image


@dataclass
class SamplingOptions:
    prompt: str
    width: int
    height: int
    num_steps: int
    guidance: float
    seed: int | None
    img_cond_path: str
    img_mask_path: str


def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
    user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
    usage = (
        "Usage: Either write your prompt directly, leave this field empty "
        "to repeat the prompt or write a command starting with a slash:\n"
        "- '/s <seed>' sets the next seed\n"
        "- '/g <guidance>' sets the guidance (flux-dev only)\n"
        "- '/n <steps>' sets the number of steps\n"
        "- '/q' to quit"
    )

    while (prompt := input(user_question)).startswith("/"):
        if prompt.startswith("/g"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, guidance = prompt.split()
            options.guidance = float(guidance)
            print(f"Setting guidance to {options.guidance}")
        elif prompt.startswith("/s"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, seed = prompt.split()
            options.seed = int(seed)
            print(f"Setting seed to {options.seed}")
        elif prompt.startswith("/n"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, steps = prompt.split()
            options.num_steps = int(steps)
            print(f"Setting number of steps to {options.num_steps}")
        elif prompt.startswith("/q"):
            print("Quitting")
            return None
        else:
            if not prompt.startswith("/h"):
                print(f"Got invalid command '{prompt}'\n{usage}")
            print(usage)
    if prompt != "":
        options.prompt = prompt
    return options


def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None:
    if options is None:
        return None

    user_question = "Next conditioning image (write /h for help, /q to quit and leave empty to repeat):\n"
    usage = (
        "Usage: Either write the path to the conditioning image directly, leave this field "
        "empty to reuse the current conditioning image or write a command starting with a slash:\n"
        "- '/q' to quit"
    )

    while True:
        img_cond_path = input(user_question)

        if img_cond_path.startswith("/"):
            if img_cond_path.startswith("/q"):
                print("Quitting")
                return None
            else:
                if not img_cond_path.startswith("/h"):
                    print(f"Got invalid command '{img_cond_path}'\n{usage}")
                print(usage)
            continue

        if img_cond_path == "":
            break

        if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith(
            (".jpg", ".jpeg", ".png", ".webp")
        ):
            print(f"File '{img_cond_path}' does not exist or is not a valid image file")
            continue
        else:
            with Image.open(img_cond_path) as img:
                width, height = img.size

            if width % 32 != 0 or height % 32 != 0:
                print(f"Image dimensions must be divisible by 32, got {width}x{height}")
                continue

        options.img_cond_path = img_cond_path
        break

    return options


def parse_img_mask_path(options: SamplingOptions | None) -> SamplingOptions | None:
    if options is None:
        return None

    user_question = "Next conditioning mask (write /h for help, /q to quit and leave empty to repeat):\n"
    usage = (
        "Usage: Either write the path to the conditioning mask directly, leave this field "
        "empty to reuse the current mask or write a command starting with a slash:\n"
        "- '/q' to quit"
    )

    while True:
        img_mask_path = input(user_question)

        if img_mask_path.startswith("/"):
            if img_mask_path.startswith("/q"):
                print("Quitting")
                return None
            else:
                if not img_mask_path.startswith("/h"):
                    print(f"Got invalid command '{img_mask_path}'\n{usage}")
                print(usage)
            continue

        if img_mask_path == "":
            break

        if not os.path.isfile(img_mask_path) or not img_mask_path.lower().endswith(
            (".jpg", ".jpeg", ".png", ".webp")
        ):
            print(f"File '{img_mask_path}' does not exist or is not a valid image file")
            continue
        else:
            with Image.open(img_mask_path) as img:
                width, height = img.size

            if width % 32 != 0 or height % 32 != 0:
                print(f"Image dimensions must be divisible by 32, got {width}x{height}")
                continue
            else:
                with Image.open(options.img_cond_path) as img_cond:
                    img_cond_width, img_cond_height = img_cond.size

                if width != img_cond_width or height != img_cond_height:
                    print(
                        f"Mask dimensions must match conditioning image, got {width}x{height} and {img_cond_width}x{img_cond_height}"
                    )
                    continue

        options.img_mask_path = img_mask_path
        break

    return options


@torch.inference_mode()
def main(
    seed: int | None = None,
    prompt: str = "a white paper cup",
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    num_steps: int = 50,
    loop: bool = False,
    guidance: float = 30.0,
    offload: bool = False,
    output_dir: str = "output",
    add_sampling_metadata: bool = True,
    img_cond_path: str = "assets/cup.png",
    img_mask_path: str = "assets/cup_mask.png",
    track_usage: bool = False,
):
    """
    Sample the flux model. Either interactively (set `--loop`) or run for a
    single image. This demo assumes that the conditioning image and mask have
    the same shape and that height and width are divisible by 32.

    Args:
        seed: Set a seed for sampling
        output_dir: directory to save output images in; files are named
            `img_{idx}.jpg`, where `{idx}` continues after the highest existing index
        prompt: Prompt used for sampling
        device: Pytorch device
        num_steps: number of sampling steps
        loop: start an interactive session and sample multiple times
        guidance: guidance value used for guidance distillation
        offload: move models between CPU and device so that only the one
            currently in use occupies GPU memory
        add_sampling_metadata: Add the prompt to the image Exif metadata
        img_cond_path: path to conditioning image (jpeg/png/webp)
        img_mask_path: path to conditioning mask (jpeg/png/webp)
        track_usage: track usage of the model for licensing purposes
    """
    nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)

    name = "flux-dev-fill"
    if name not in configs:
        available = ", ".join(configs.keys())
        raise ValueError(f"Got unknown model name: {name}, choose from {available}")

    torch_device = torch.device(device)

    output_name = os.path.join(output_dir, "img_{idx}.jpg")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        idx = 0
    else:
        fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
        if len(fns) > 0:
            idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
        else:
            idx = 0

    # init all components
    t5 = load_t5(torch_device, max_length=128)
    clip = load_clip(torch_device)
    model = load_flow_model(name, device="cpu" if offload else torch_device)
    ae = load_ae(name, device="cpu" if offload else torch_device)

    rng = torch.Generator(device="cpu")
    with Image.open(img_cond_path) as img:
        width, height = img.size
    opts = SamplingOptions(
        prompt=prompt,
        width=width,
        height=height,
        num_steps=num_steps,
        guidance=guidance,
        seed=seed,
        img_cond_path=img_cond_path,
        img_mask_path=img_mask_path,
    )

    if loop:
        opts = parse_prompt(opts)
        opts = parse_img_cond_path(opts)

        with Image.open(opts.img_cond_path) as img:
            width, height = img.size
        opts.height = height
        opts.width = width

        opts = parse_img_mask_path(opts)

    while opts is not None:
        if opts.seed is None:
            opts.seed = rng.seed()
        print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
        t0 = time.perf_counter()

        # prepare input
        x = get_noise(
            1,
            opts.height,
            opts.width,
            device=torch_device,
            dtype=torch.bfloat16,
            seed=opts.seed,
        )
        opts.seed = None
        if offload:
            t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device)
        inp = prepare_fill(
            t5,
            clip,
            x,
            prompt=opts.prompt,
            ae=ae,
            img_cond_path=opts.img_cond_path,
            mask_path=opts.img_mask_path,
        )

        timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))

        # offload TEs and AE to CPU, load model to gpu
        if offload:
            t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu()
            torch.cuda.empty_cache()
            model = model.to(torch_device)

        # denoise initial noise
        x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)

        # offload model, load autoencoder to gpu
        if offload:
            model.cpu()
            torch.cuda.empty_cache()
            ae.decoder.to(x.device)

        # decode latents to pixel space
        x = unpack(x.float(), opts.height, opts.width)
        with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
            x = ae.decode(x)

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t1 = time.perf_counter()
        print(f"Done in {t1 - t0:.1f}s")

        idx = save_image(
            nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, opts.prompt, track_usage=track_usage
        )

        if loop:
            print("-" * 80)
            opts = parse_prompt(opts)
            opts = parse_img_cond_path(opts)

            with Image.open(opts.img_cond_path) as img:
                width, height = img.size
            opts.height = height
            opts.width = width

            opts = parse_img_mask_path(opts)
        else:
            opts = None


if __name__ == "__main__":
    Fire(main)
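
All three CLIs in this listing choose the next output index the same way: scan `output_dir` for files named `img_<N>.jpg` and continue from one past the largest `N`. The sketch below restates that logic as a standalone function so it can be checked without a filesystem. The name `next_output_index` is hypothetical (it is not part of the flux package), and it takes a list of file names instead of calling `iglob` itself.

```python
import re
from collections.abc import Iterable

# Matches names ending in img_<digits>.jpg and captures the digits.
_OUTPUT_NAME = re.compile(r"img_([0-9]+)\.jpg$")


def next_output_index(filenames: Iterable[str]) -> int:
    """Return one past the largest N among files named img_N.jpg, or 0 if none match.

    Hypothetical helper mirroring the inline index scan in the flux CLIs.
    """
    indices = [int(m.group(1)) for fn in filenames if (m := _OUTPUT_NAME.search(fn))]
    return max(indices) + 1 if indices else 0


if __name__ == "__main__":
    existing = ["output/img_0.jpg", "output/img_7.jpg", "output/notes.txt"]
    print(next_output_index(existing))  # 8
```

A regex capture group reads `N` directly, which avoids splitting the path on `_` and `.` as the inline version does.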


================================================
FILE: src/flux/cli_kontext.py
================================================
import os
import re
import time
from dataclasses import dataclass
from glob import iglob

import torch
from fire import Fire

from flux.content_filters import PixtralContentFilter
from flux.sampling import denoise, get_schedule, prepare_kontext, unpack
from flux.util import (
    aspect_ratio_to_height_width,
    check_onnx_access_for_trt,
    load_ae,
    load_clip,
    load_flow_model,
    load_t5,
    save_image,
)


@dataclass
class SamplingOptions:
    prompt: str
    width: int | None
    height: int | None
    num_steps: int
    guidance: float
    seed: int | None
    img_cond_path: str


def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
    user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
    usage = (
        "Usage: Either write your prompt directly, leave this field empty "
        "to repeat the prompt or write a command starting with a slash:\n"
        "- '/ar <width>:<height>' will set the aspect ratio of the generated image "
        "('/ar auto' uses the input image resolution)\n"
        "- '/h <height>' will set the height of the generated image "
        "('/h auto' defaults it to the input image height)\n"
        "- '/s <seed>' sets the next seed\n"
        "- '/g <guidance>' sets the guidance (flux-dev only)\n"
        "- '/n <steps>' sets the number of steps\n"
        "- '/q' to quit"
    )

    while (prompt := input(user_question)).startswith("/"):
        if prompt.startswith("/ar"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, ratio_prompt = prompt.split()
            if ratio_prompt == "auto":
                options.width = None
                options.height = None
                print("Setting resolution to input image resolution.")
            else:
                options.width, options.height = aspect_ratio_to_height_width(ratio_prompt)
                print(f"Setting resolution to {options.width} x {options.height}.")
        elif prompt.startswith("/h"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, height = prompt.split()
            if height == "auto":
                options.height = None
            else:
                options.height = 16 * (int(height) // 16)
            if options.height is not None and options.width is not None:
                print(
                    f"Setting resolution to {options.width} x {options.height} "
                    f"({options.height * options.width / 1e6:.2f}MP)"
                )
            else:
                print(f"Setting resolution to {options.width} x {options.height}.")
        elif prompt.startswith("/g"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, guidance = prompt.split()
            options.guidance = float(guidance)
            print(f"Setting guidance to {options.guidance}")
        elif prompt.startswith("/s"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, seed = prompt.split()
            options.seed = int(seed)
            print(f"Setting seed to {options.seed}")
        elif prompt.startswith("/n"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, steps = prompt.split()
            options.num_steps = int(steps)
            print(f"Setting number of steps to {options.num_steps}")
        elif prompt.startswith("/q"):
            print("Quitting")
            return None
        else:
            if not prompt.startswith("/h"):
                print(f"Got invalid command '{prompt}'\n{usage}")
            print(usage)
    if prompt != "":
        options.prompt = prompt
    return options


def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None:
    if options is None:
        return None

    user_question = "Next input image (write /h for help, /q to quit and leave empty to repeat):\n"
    usage = (
        "Usage: Either write a path to an image directly, leave this field empty "
        "to repeat the last input image or write a command starting with a slash:\n"
        "- '/q' to quit\n\n"
        "The input image will be edited by FLUX.1 Kontext, creating a new image based "
        "on your instruction prompt."
    )

    while True:
        img_cond_path = input(user_question)

        if img_cond_path.startswith("/"):
            if img_cond_path.startswith("/q"):
                print("Quitting")
                return None
            else:
                if not img_cond_path.startswith("/h"):
                    print(f"Got invalid command '{img_cond_path}'\n{usage}")
                print(usage)
            continue

        if img_cond_path == "":
            break

        if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith(
            (".jpg", ".jpeg", ".png", ".webp")
        ):
            print(f"File '{img_cond_path}' does not exist or is not a valid image file")
            continue

        options.img_cond_path = img_cond_path
        break

    return options


@torch.inference_mode()
def main(
    name: str = "flux-dev-kontext",
    aspect_ratio: str | None = None,
    seed: int | None = None,
    prompt: str = "replace the logo with the text 'Black Forest Labs'",
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    num_steps: int = 30,
    loop: bool = False,
    guidance: float = 2.5,
    offload: bool = False,
    output_dir: str = "output",
    add_sampling_metadata: bool = True,
    img_cond_path: str = "assets/cup.png",
    trt: bool = False,
    trt_transformer_precision: str = "bf16",
    track_usage: bool = False,
):
    """
    Sample the flux model. Either interactively (set `--loop`) or run for a
    single image.

    Args:
        name: Name of the model to use (only 'flux-dev-kontext' is supported)
        aspect_ratio: aspect ratio of the sample as '<width>:<height>', None
            defaults to the size of the conditioning image
        seed: Set a seed for sampling
        prompt: Prompt used for sampling
        device: PyTorch device
        num_steps: number of sampling steps
        loop: start an interactive session and sample multiple times
        guidance: guidance value used for guidance distillation
        offload: offload models to CPU when not in use
        output_dir: where to save the output images (written as `img_{idx}.jpg`)
        add_sampling_metadata: Add the prompt to the image Exif metadata
        img_cond_path: path to conditioning image (jpeg/png/webp)
        trt: use TensorRT backend for optimized inference
        trt_transformer_precision: precision of the TensorRT transformer engine
        track_usage: track usage of the model for licensing purposes
    """
    assert name == "flux-dev-kontext", f"Got unknown model name: {name}"

    torch_device = torch.device(device)

    output_name = os.path.join(output_dir, "img_{idx}.jpg")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        idx = 0
    else:
        fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
        if len(fns) > 0:
            idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
        else:
            idx = 0

    if aspect_ratio is None:
        width = None
        height = None
    else:
        width, height = aspect_ratio_to_height_width(aspect_ratio)

    if not trt:
        t5 = load_t5(torch_device, max_length=512)
        clip = load_clip(torch_device)
        model = load_flow_model(name, device="cpu" if offload else torch_device)
    else:
        # lazy import so the TensorRT dependencies remain optional
        from flux.trt.trt_manager import ModuleName, TRTManager

        # Check if we need ONNX model access (which requires authentication for FLUX models)
        onnx_dir = check_onnx_access_for_trt(name, trt_transformer_precision)

        trt_ctx_manager = TRTManager(
            trt_transformer_precision=trt_transformer_precision,
            trt_t5_precision=os.environ.get("TRT_T5_PRECISION", "bf16"),
        )
        engines = trt_ctx_manager.load_engines(
            model_name=name,
            module_names={
                ModuleName.CLIP,
                ModuleName.TRANSFORMER,
                ModuleName.T5,
            },
            engine_dir=os.environ.get("TRT_ENGINE_DIR", "./engines"),
            custom_onnx_paths=onnx_dir or os.environ.get("CUSTOM_ONNX_PATHS", ""),
            trt_image_height=height,
            trt_image_width=width,
            trt_batch_size=1,
            trt_timing_cache=os.getenv("TRT_TIMING_CACHE_FILE", None),
            trt_static_batch=False,
            trt_static_shape=False,
        )

        model = engines[ModuleName.TRANSFORMER].to(device="cpu" if offload else torch_device)
        clip = engines[ModuleName.CLIP].to(torch_device)
        t5 = engines[ModuleName.T5].to(device="cpu" if offload else torch_device)

    ae = load_ae(name, device="cpu" if offload else torch_device)
    content_filter = PixtralContentFilter(torch.device("cpu"))

    rng = torch.Generator(device="cpu")
    opts = SamplingOptions(
        prompt=prompt,
        width=width,
        height=height,
        num_steps=num_steps,
        guidance=guidance,
        seed=seed,
        img_cond_path=img_cond_path,
    )

    if loop:
        opts = parse_prompt(opts)
        opts = parse_img_cond_path(opts)

    while opts is not None:
        if opts.seed is None:
            opts.seed = rng.seed()
        print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
        t0 = time.perf_counter()

        if content_filter.test_txt(opts.prompt):
            print("Your prompt has been automatically flagged. Please choose another prompt.")
            if loop:
                print("-" * 80)
                opts = parse_prompt(opts)
                opts = parse_img_cond_path(opts)
            else:
                opts = None
            continue
        if content_filter.test_image(opts.img_cond_path):
            print("Your input image has been automatically flagged. Please choose another image.")
            if loop:
                print("-" * 80)
                opts = parse_prompt(opts)
                opts = parse_img_cond_path(opts)
            else:
                opts = None
            continue

        if offload:
            t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device)
        inp, height, width = prepare_kontext(
            t5=t5,
            clip=clip,
            prompt=opts.prompt,
            ae=ae,
            img_cond_path=opts.img_cond_path,
            target_width=opts.width,
            target_height=opts.height,
            bs=1,
            seed=opts.seed,
            device=torch_device,
        )
        inp.pop("img_cond_orig")
        opts.seed = None
        timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))

        # offload TEs and AE to CPU, load model to gpu
        if offload:
            t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu()
            torch.cuda.empty_cache()
            model = model.to(torch_device)

        # denoise initial noise
        t00 = time.perf_counter()
        x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t01 = time.perf_counter()
        print(f"Denoising took {t01 - t00:.3f}s")

        # offload model, load autoencoder to gpu
        if offload:
            model.cpu()
            torch.cuda.empty_cache()
            ae.decoder.to(x.device)

        # decode latents to pixel space
        x = unpack(x.float(), height, width)
        with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
            ae_dec_t0 = time.perf_counter()
            x = ae.decode(x)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            ae_dec_t1 = time.perf_counter()
            print(f"AE decode took {ae_dec_t1 - ae_dec_t0:.3f}s")

        if content_filter.test_image(x.cpu()):
            print(
                "Your output image has been automatically flagged. Choose another prompt/image or try again."
            )
            if loop:
                print("-" * 80)
                opts = parse_prompt(opts)
                opts = parse_img_cond_path(opts)
            else:
                opts = None
            continue

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t1 = time.perf_counter()
        print(f"Done in {t1 - t0:.1f}s")

        idx = save_image(
            None, name, output_name, idx, x, add_sampling_metadata, opts.prompt, track_usage=track_usage
        )

        if loop:
            print("-" * 80)
            opts = parse_prompt(opts)
            opts = parse_img_cond_path(opts)
        else:
            opts = None


if __name__ == "__main__":
    Fire(main)
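
In `cli_kontext.py`, the `/h <height>` command either clears the height (`auto`, so it follows the input image) or rounds the requested value down to a multiple of 16, and the loop then reports the resolution in megapixels. A minimal sketch of those two calculations, using the hypothetical helper names `parse_dimension` and `megapixels`:

```python
from __future__ import annotations


def parse_dimension(value: str, multiple: int = 16) -> int | None:
    """Parse a '/h' argument: 'auto' gives None, anything else is rounded down to `multiple`."""
    if value == "auto":
        return None  # None means the size follows the input image
    return multiple * (int(value) // multiple)


def megapixels(width: int, height: int) -> float:
    """Image area in megapixels, as printed by the interactive loop."""
    return width * height / 1e6


if __name__ == "__main__":
    height = parse_dimension("1000")
    print(height)  # 992
    print(f"{megapixels(1024, height):.2f}MP")  # 1.02MP
```

As in the original, `int(value)` raises `ValueError` for non-numeric input, and values below 16 round down to 0, so a caller may want to validate before assigning.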


================================================
FILE: src/flux/cli_redux.py
================================================
import os
import re
import time
from dataclasses import dataclass
from glob import iglob

import torch
from fire import Fire
from transformers import pipeline

from flux.modules.image_embedders import ReduxImageEncoder
from flux.sampling import denoise, get_noise, get_schedule, prepare_redux, unpack
from flux.util import (
    get_checkpoint_path,
    load_ae,
    load_clip,
    load_flow_model,
    load_t5,
    save_image,
)


@dataclass
class SamplingOptions:
    prompt: str
    width: int
    height: int
    num_steps: int
    guidance: float
    seed: int | None
    img_cond_path: str


def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
    user_question = "Write /h for help, /q to quit or leave empty to repeat:\n"
    usage = (
        "Usage: Leave this field empty to do nothing "
        "or write a command starting with a slash:\n"
        "- '/w <width>' will set the width of the generated image\n"
        "- '/h <height>' will set the height of the generated image\n"
        "- '/s <seed>' sets the next seed\n"
        "- '/g <guidance>' sets the guidance (flux-dev only)\n"
        "- '/n <steps>' sets the number of steps\n"
        "- '/q' to quit"
    )

    while (prompt := input(user_question)).startswith("/"):
        if prompt.startswith("/w"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, width = prompt.split()
            options.width = 16 * (int(width) // 16)
            print(
                f"Setting resolution to {options.width} x {options.height} "
                f"({options.height * options.width / 1e6:.2f}MP)"
            )
        elif prompt.startswith("/h"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, height = prompt.split()
            options.height = 16 * (int(height) // 16)
            print(
                f"Setting resolution to {options.width} x {options.height} "
                f"({options.height * options.width / 1e6:.2f}MP)"
            )
        elif prompt.startswith("/g"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, guidance = prompt.split()
            options.guidance = float(guidance)
            print(f"Setting guidance to {options.guidance}")
        elif prompt.startswith("/s"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, seed = prompt.split()
            options.seed = int(seed)
            print(f"Setting seed to {options.seed}")
        elif prompt.startswith("/n"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, steps = prompt.split()
            options.num_steps = int(steps)
            print(f"Setting number of steps to {options.num_steps}")
        elif prompt.startswith("/q"):
            print("Quitting")
            return None
        else:
            if not prompt.startswith("/h"):
                print(f"Got invalid command '{prompt}'\n{usage}")
            print(usage)
    return options


def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None:
    if options is None:
        return None

    user_question = "Next conditioning image (write /h for help, /q to quit and leave empty to repeat):\n"
    usage = (
        "Usage: Either write a path to an image directly, leave this field empty "
        "to repeat the conditioning image or write a command starting with a slash:\n"
        "- '/q' to quit"
    )

    while True:
        img_cond_path = input(user_question)

        if img_cond_path.startswith("/"):
            if img_cond_path.startswith("/q"):
                print("Quitting")
                return None
            else:
                if not img_cond_path.startswith("/h"):
                    print(f"Got invalid command '{img_cond_path}'\n{usage}")
                print(usage)
            continue

        if img_cond_path == "":
            break

        if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith(
            (".jpg", ".jpeg", ".png", ".webp")
        ):
            print(f"File '{img_cond_path}' does not exist or is not a valid image file")
            continue

        options.img_cond_path = img_cond_path
        break

    return options


@torch.inference_mode()
def main(
    name: str = "flux-dev",
    width: int = 1360,
    height: int = 768,
    seed: int | None = None,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    num_steps: int | None = None,
    loop: bool = False,
    guidance: float = 2.5,
    offload: bool = False,
    output_dir: str = "output",
    add_sampling_metadata: bool = True,
    img_cond_path: str = "assets/robot.webp",
    track_usage: bool = False,
):
    """
    Sample the flux model. Either interactively (set `--loop`) or run for a
    single image.

    Args:
        name: Name of the base model to use (either 'flux-dev' or 'flux-schnell')
        height: height of the sample in pixels (should be a multiple of 16)
        width: width of the sample in pixels (should be a multiple of 16)
        seed: Set a seed for sampling
        device: PyTorch device
        num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
        loop: start an interactive session and sample multiple times
        guidance: guidance value used for guidance distillation
        offload: offload models to CPU when not in use
        output_dir: where to save the output images
        add_sampling_metadata: Add the prompt to the image Exif metadata
        img_cond_path: path to conditioning image (jpeg/png/webp)
        track_usage: track usage of the model for licensing purposes
    """

    nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)

    if name not in (available := ["flux-dev", "flux-schnell"]):
        raise ValueError(f"Got unknown model name: {name}, choose from {available}")

    torch_device = torch.device(device)
    if num_steps is None:
        num_steps = 4 if name == "flux-schnell" else 50

    output_name = os.path.join(output_dir, "img_{idx}.jpg")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        idx = 0
    else:
        fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
        if len(fns) > 0:
            idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
        else:
            idx = 0

    # init all components
    t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512)
    clip = load_clip(torch_device)
    model = load_flow_model(name, device="cpu" if offload else torch_device)
    ae = load_ae(name, device="cpu" if offload else torch_device)

    # Download and initialize the Redux adapter
    redux_path = str(
        get_checkpoint_path("black-forest-labs/FLUX.1-Redux-dev", "flux1-redux-dev.safetensors", "FLUX_REDUX")
    )
    img_embedder = ReduxImageEncoder(torch_device, redux_path=redux_path)

    rng = torch.Generator(device="cpu")
    prompt = ""
    opts = SamplingOptions(
        prompt=prompt,
        width=width,
        height=height,
        num_steps=num_steps,
        guidance=guidance,
        seed=seed,
        img_cond_path=img_cond_path,
    )

    if loop:
        opts = parse_prompt(opts)
        opts = parse_img_cond_path(opts)

    while opts is not None:
        if opts.seed is None:
            opts.seed = rng.seed()
        print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
        t0 = time.perf_counter()

        # prepare input
        x = get_noise(
            1,
            opts.height,
            opts.width,
            device=torch_device,
            dtype=torch.bfloat16,
            seed=opts.seed,
        )
        opts.seed = None
        if offload:
            ae = ae.cpu()
            torch.cuda.empty_cache()
            t5, clip = t5.to(torch_device), clip.to(torch_device)
        inp = prepare_redux(
            t5,
            clip,
            x,
            prompt=opts.prompt,
            encoder=img_embedder,
            img_cond_path=opts.img_cond_path,
        )
        timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))

        # offload TEs to CPU, load model to gpu
        if offload:
            t5, clip = t5.cpu(), clip.cpu()
            torch.cuda.empty_cache()
            model = model.to(torch_device)

        # denoise initial noise
        x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)

        # offload model, load autoencoder to gpu
        if offload:
            model.cpu()
            torch.cuda.empty_cache()
            ae.decoder.to(x.device)

        # decode latents to pixel space
        x = unpack(x.float(), opts.height, opts.width)
        with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
            x = ae.decode(x)

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t1 = time.perf_counter()
        print(f"Done in {t1 - t0:.1f}s")

        idx = save_image(
            nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, prompt, track_usage=track_usage
        )

        if loop:
            print("-" * 80)
            opts = parse_prompt(opts)
            opts = parse_img_cond_path(opts)
        else:
            opts = None


if __name__ == "__main__":
    Fire(main)
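
The interactive loops accept commands of the form `/<cmd> <value>` and reject anything without exactly one space. One edge case slips through that check: an input like `/s ` (trailing space, no value) has one space, but `prompt.split()` then returns a single item and the tuple unpacking raises `ValueError`. The sketch below is a hypothetical table-driven alternative (`split_command`, `apply_command` and `SETTERS` are illustrative names, and it stores options in a plain dict rather than the `SamplingOptions` dataclass). Unlike the original `startswith` checks, it matches command names exactly, so `/seed 4` is not treated as `/s`.

```python
from __future__ import annotations

# Hypothetical command table: command -> (option name, converter for the value).
SETTERS = {
    "/s": ("seed", int),
    "/g": ("guidance", float),
    "/n": ("num_steps", int),
}


def split_command(text: str) -> tuple[str, str] | None:
    """Split '/cmd value' into ('/cmd', 'value').

    Returns None unless `text` starts with '/' and contains exactly one space
    separating a non-empty command name from a non-empty value.
    """
    if not text.startswith("/") or text.count(" ") != 1:
        return None
    cmd, value = text.split(" ")
    if len(cmd) < 2 or not value:
        return None
    return cmd, value


def apply_command(options: dict[str, object], text: str) -> bool:
    """Apply a '/s', '/g' or '/n' command to `options`; return False if it is not one."""
    parsed = split_command(text)
    if parsed is None or parsed[0] not in SETTERS:
        return False
    field, convert = SETTERS[parsed[0]]
    options[field] = convert(parsed[1])  # may raise ValueError on non-numeric input
    return True


if __name__ == "__main__":
    opts = {"seed": None, "guidance": 2.5, "num_steps": 50}
    apply_command(opts, "/n 28")
    print(opts)  # {'seed': None, 'guidance': 2.5, 'num_steps': 28}
```

Commands without a value, such as `/q`, return None from `split_command` and would be handled separately, as the original loops already do.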


================================================
FILE: src/flux/content_filters.py
================================================
import torch
from einops import rearrange
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration, pipeline

PROMPT_IMAGE_INTEGRITY = """
Task: Analyze an image to identify potential copyright concerns or depictions of public figures.

Output: Respond with only "yes" or "no"

Criteria for "yes":
- The image contains a recognizable character from copyrighted media (movies, TV, comics, games, etc.)
- The image displays a trademarked logo or brand
- The image depicts a recognizable public figure (celebrities, politicians, athletes, influencers, historical figures, etc.)

Criteria for "no":
- All other cases
- When you cannot identify the specific copyrighted work or named individual

Critical Requirements:
1. You must be able to name the exact copyrighted work or specific person depicted
2. General references to demographics or characteristics are not sufficient
3. Base your decision solely on visual content, not interpretation
4. Provide only the one-word answer: "yes" or "no"
""".strip()


PROMPT_IMAGE_INTEGRITY_FOLLOW_UP = "Does this image have copyright concerns or include public figures?"

PROMPT_TEXT_INTEGRITY = """
Task: Analyze a text prompt to identify potential copyright concerns or requests to depict living public figures.

Output: Respond with only "yes" or "no"

Criteria for "Yes":
- The prompt explicitly names a character from copyrighted media (movies, TV, comics, games, etc.)
- The prompt explicitly mentions a trademarked logo or brand
- The prompt names or describes a specific living public figure (celebrities, politicians, athletes, influencers, etc.)

Criteria for "No":
- All other cases
- When you cannot identify the specific copyrighted work or named individual

Critical Requirements:
1. You must be able to name the exact copyrighted work or specific person referenced
2. General demographic descriptions or characteristics are not sufficient
3. Analyze only the prompt text, not potential image outcomes
4. Provide only the one-word answer: "yes" or "no"

The prompt to check is:
-----
{prompt}
-----

Does this prompt have copyright concerns or include public figures?
""".strip()


class PixtralContentFilter(torch.nn.Module):
    def __init__(
        self,
        device: torch.device = torch.device("cpu"),
        nsfw_threshold: float = 0.85,
    ):
        super().__init__()

        model_id = "mistral-community/pixtral-12b"
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = LlavaForConditionalGeneration.from_pretrained(model_id, device_map=device)

        self.yes_token, self.no_token = self.processor.tokenizer.encode(["yes", "no"])

        self.nsfw_classifier = pipeline(
            "image-classification", model="Falconsai/nsfw_image_detection", device=device
        )
        self.nsfw_threshold = nsfw_threshold

    def yes_no_logit_processor(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        """
        Masks every token except yes/no by setting its score to one below the
        current minimum, so greedy decoding can only produce "yes" or "no".
        """
        scores_yes_token = scores[:, self.yes_token].clone()
        scores_no_token = scores[:, self.no_token].clone()
        scores_min = scores.min()
        scores[:, :] = scores_min - 1
        scores[:, self.yes_token] = scores_yes_token
        scores[:, self.no_token] = scores_no_token
        return scores

    def test_image(self, image: Image.Image | str | torch.Tensor) -> bool:
        if isinstance(image, torch.Tensor):
            image = rearrange(image[0].clamp(-1.0, 1.0), "c h w -> h w c")
            image = Image.fromarray((127.5 * (image + 1.0)).cpu().byte().numpy())
        elif isinstance(image, str):
            image = Image.open(image)

        classification = next(c for c in self.nsfw_classifier(image) if c["label"] == "nsfw")
        if classification["score"] > self.nsfw_threshold:
            return True

        # Rescale to about 512^2 pixels (keeping the aspect ratio); that is enough for the check
        w, h = image.size
        f = (512**2 / (w * h)) ** 0.5
        image = image.resize((int(f * w), int(f * h)))

        chat = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "content": PROMPT_IMAGE_INTEGRITY,
                    },
                    {
                        "type": "image",
                        "image": image,
                    },
                    {
                        "type": "text",
                        "content": PROMPT_IMAGE_INTEGRITY_FOLLOW_UP,
                    },
                ],
            }
        ]

        inputs = self.processor.apply_chat_template(
            chat,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device)

        generate_ids = self.model.generate(
            **inputs,
            max_new_tokens=1,
            logits_processor=[self.yes_no_logit_processor],
            do_sample=False,
        )
        return generate_ids[0, -1].item() == self.yes_token

    def test_txt(self, txt: str) -> bool:
        chat = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "content": PROMPT_TEXT_INTEGRITY.format(prompt=txt),
                    },
                ],
            }
        ]

        inputs = self.processor.apply_chat_template(
            chat,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device)

        generate_ids = self.model.generate(
            **inputs,
            max_new_tokens=1,
            logits_processor=[self.yes_no_logit_processor],
            do_sample=False,
        )
        return generate_ids[0, -1].item() == self.yes_token


================================================
FILE: src/flux/math.py
================================================
import torch
from einops import rearrange
from torch import Tensor


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    q, k = apply_rope(q, k, pe)

    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    x = rearrange(x, "B H L D -> B L (H D)")

    return x


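def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    assert dim % 2 == 0
    # One frequency per channel pair: omega_i = theta ** (-2i / dim)
    scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    # Rotation angle for every (position, frequency) combination
    out = torch.einsum("...n,d->...nd", pos, omega)
    # Flattened 2x2 rotation matrices [[cos, -sin], [sin, cos]], reshaped to (..., 2, 2)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.float()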


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    # View the head dim as (dim / 2) pairs shaped (1, 2) so they broadcast against the 2x2 matrices
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    # Rotate each pair: first matrix column times x0 plus second matrix column times x1
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
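rope builds, for every position and channel pair, the 2×2 rotation matrix [[cos, -sin], [sin, cos]] with angle pos · theta^(-2i/dim), and apply_rope multiplies each (x[2i], x[2i+1]) pair of the queries and keys by it. Because q and k are both rotated, their dot product depends only on the difference of their positions. The plain-Python sketch below checks that property numerically for a single positional axis, without batching or heads; the helper names are illustrative and not repository API.

```python
import math


def rope_angles(pos: float, dim: int, theta: float = 10_000.0) -> list[float]:
    """Angle per channel pair, pos * theta**(-2i/dim), mirroring flux.math.rope."""
    assert dim % 2 == 0
    return [pos / theta ** (i / dim) for i in range(0, dim, 2)]


def apply_rope_1d(x: list[float], pos: float, theta: float = 10_000.0) -> list[float]:
    """Rotate each consecutive pair (x[2i], x[2i+1]) by its angle."""
    out: list[float] = []
    pairs = zip(x[0::2], x[1::2])
    for angle, (x0, x1) in zip(rope_angles(pos, len(x), theta), pairs):
        c, s = math.cos(angle), math.sin(angle)
        # Same arithmetic as apply_rope: column 0 of R times x0 plus column 1 times x1
        out.extend([c * x0 - s * x1, s * x0 + c * x1])
    return out


def dot(u: list[float], v: list[float]) -> float:
    return sum(a * b for a, b in zip(u, v))


q = [0.3, -1.2, 0.5, 0.8]
k = [1.0, 0.4, -0.7, 0.2]
# The same offset (3) at different absolute positions gives the same attention logit
print(dot(apply_rope_1d(q, 5), apply_rope_1d(k, 2)))
print(dot(apply_rope_1d(q, 13), apply_rope_1d(k, 10)))
```

Since each pair is rotated rather than scaled, vector norms are also preserved, which is why RoPE can be applied to q and k just before attention without rescaling.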


================================================
FILE: src/flux/model.py
================================================
from dataclasses import dataclass

import torch
from torch import Tensor, nn

from flux.modules.layers import (
    DoubleStreamBlock,
    EmbedND,
    LastLayer,
    MLPEmbedder,
    SingleStreamBlock,
    timestep_embedding,
)
from flux.modules.lora import LinearLora, replace_linear_with_lora


@dataclass
class FluxParams:
    in_channels: int
    out_channels: int
    vec_in_dim: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    axes_dim: list[int]
    theta: int
    qkv_bias: bool
    guidance_embed: bool


class Flux(nn.Module):
    """
    Transformer model for flow matching on sequences.
    """

    def __init__(self, params: FluxParams):
        super().__init__()

        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = params.out_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
        )
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # img and txt arrive as token sequences of shape (batch, seq_len, channels)
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        img = img[:, txt.shape[1] :, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img


class FluxLoraWrapper(Flux):
    def __init__(
        self,
        lora_rank: int = 128,
        lora_scale: float = 1.0,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)

        self.lora_rank = lora_rank

        replace_linear_with_lora(
            self,
            max_rank=lora_rank,
            scale=lora_scale,
        )

    def set_lora_scale(self, scale: float) -> None:
        for module in self.modules():
            if isinstance(module, LinearLora):
                module.set_scale(scale=scale)
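Flux.__init__ requires hidden_size to be divisible by num_heads and the per-axis RoPE widths in axes_dim to sum to the per-head dimension. forward then concatenates the ids as (txt_ids, img_ids), runs the single-stream blocks on the matching sequence [txt; img], and slices the text tokens back off. A list-based sketch of both bookkeeping steps follows; the helper names are hypothetical, and the numbers in the usage lines (hidden_size=3072, num_heads=24, axes_dim=[16, 56, 56]) are assumed for illustration and should be checked against the model configs in util.py.

```python
def positional_dim(hidden_size: int, num_heads: int, axes_dim: list[int]) -> int:
    """Mirror the two shape checks in Flux.__init__ and return the per-head RoPE width."""
    if hidden_size % num_heads != 0:
        raise ValueError(f"hidden_size {hidden_size} not divisible by num_heads {num_heads}")
    pe_dim = hidden_size // num_heads
    if sum(axes_dim) != pe_dim:
        raise ValueError(f"axes_dim {axes_dim} sums to {sum(axes_dim)}, expected {pe_dim}")
    return pe_dim


def single_stream_roundtrip(txt: list[str], img: list[str]) -> list[str]:
    """Text tokens go first (matching the txt_ids/img_ids order); image tokens are sliced back out."""
    joint = txt + img  # torch.cat((txt, img), 1) along the sequence axis
    # ... the single-stream blocks would update every position of `joint` here ...
    return joint[len(txt):]  # img[:, txt.shape[1]:, ...]


print(positional_dim(3072, 24, [16, 56, 56]))  # 128
print(single_stream_roundtrip(["t0", "t1"], ["i0", "i1", "i2"]))  # ['i0', 'i1', 'i2']
```

Keeping text first in both the ids and the token sequence is what lets a single slice at txt.shape[1] recover the image tokens in their original order.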


================================================
FILE: src/flux/modules/autoencoder.py
================================================
from dataclasses import dataclass

import torch
from einops import rearrange
from torch import Tensor, nn


@dataclass
class AutoEncoderParams:
    resolution: int
    in_channels: int
    ch: int
    out_ch: int
    ch_mult: list[int]
    num_res_blocks: int
    z_channels: int
    scale_factor: float
    shift_factor: float


def swish(x: Tensor) -> Tensor:
    return x * torch.sigmoid(x)


class AttnBlock(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
        h_ = nn.functional.scaled_dot_product_attention(q, k, v)

        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

    def forward(self, x: Tensor) -> Tensor:
        return x + self.proj_out(self.attention(x))


class ResnetBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)

        return x + h


class Downsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        # no asymmetric padding in torch conv, must do it ourselves
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: Tensor):
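        pad = (0, 1, 0, 1)
        # Zero-pad one pixel on the right and bottom, then apply the stride-2 conv
        x = nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x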
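Downsample pads only the right and bottom edges (pad = (0, 1, 0, 1) in torch's last-dimension-first order) before a 3×3, stride-2 convolution with padding=0. For an even input size this halves the resolution, just as symmetric padding=1 would, but the sampling windows are aligned differently. A 1-D, dependency-free sketch of the window positions follows; the helper names are illustrative, not part of the repository.

```python
def conv_out_len(n: int, k: int, stride: int, pad_left: int, pad_right: int) -> int:
    """Output length of a 1-D convolution with explicit left/right zero padding."""
    return (n + pad_left + pad_right - k) // stride + 1


def window_inputs(n: int, k: int, stride: int, pad_left: int, pad_right: int) -> list[list[int]]:
    """Input indices read by each output position; -1 marks a zero-padded slot."""
    windows = []
    for o in range(conv_out_len(n, k, stride, pad_left, pad_right)):
        start = o * stride - pad_left
        windows.append([i if 0 <= i < n else -1 for i in range(start, start + k)])
    return windows


# Downsample: pad (0, 1), then kernel 3, stride 2, padding 0
print(window_inputs(6, 3, 2, pad_left=0, pad_right=1))  # [[0, 1, 2], [2, 3, 4], [4, 5, -1]]
# Symmetric padding=1 for comparison: same length, windows shifted by one
print(window_inputs(6, 3, 2, pad_left=1, pad_right=1))  # [[-1, 0, 1], [1, 2, 3], [3, 4, 5]]
```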
     
SYMBOL INDEX (266 symbols across 29 files)

FILE: demo_gr.py
  function get_models (line 27) | def get_models(name: str, device: torch.device, offload: bool, is_schnel...
  class FluxGenerator (line 36) | class FluxGenerator:
    method __init__ (line 37) | def __init__(self, model_name: str, device: str, offload: bool, track_...
    method generate_image (line 51) | def generate_image(
  function create_demo (line 175) | def create_demo(

FILE: demo_st.py
  function get_models (line 32) | def get_models(name: str, device: torch.device, offload: bool, is_schnel...
  function get_image (line 41) | def get_image() -> torch.Tensor | None:
  function main (line 58) | def main(
  function app (line 292) | def app():

FILE: demo_st_fill.py
  function add_border_and_mask (line 31) | def add_border_and_mask(image, zoom_all=1.0, zoom_left=0, zoom_right=0, ...
  function get_models (line 83) | def get_models(name: str, device: torch.device, offload: bool):
  function resize (line 92) | def resize(img: Image.Image, min_mp: float = 0.5, max_mp: float = 2.0) -...
  function clear_canvas_state (line 116) | def clear_canvas_state():
  function set_new_image (line 124) | def set_new_image(img: Image.Image):
  function downscale_image (line 131) | def downscale_image(img: Image.Image, scale_factor: float) -> Image.Image:
  function main (line 148) | def main(
  function app (line 497) | def app():

FILE: src/flux/cli.py
  class SamplingOptions (line 26) | class SamplingOptions:
  function parse_prompt (line 35) | def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
  function main (line 103) | def main(

FILE: src/flux/cli_control.py
  class SamplingOptions (line 17) | class SamplingOptions:
  function parse_prompt (line 28) | def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
  function parse_img_cond_path (line 95) | def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOpti...
  function parse_lora_scale (line 134) | def parse_lora_scale(options: SamplingOptions | None) -> tuple[SamplingO...
  function main (line 162) | def main(

FILE: src/flux/cli_fill.py
  class SamplingOptions (line 17) | class SamplingOptions:
  function parse_prompt (line 28) | def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
  function parse_img_cond_path (line 73) | def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOpti...
  function parse_img_mask_path (line 119) | def parse_img_mask_path(options: SamplingOptions | None) -> SamplingOpti...
  function main (line 175) | def main(

FILE: src/flux/cli_kontext.py
  class SamplingOptions (line 24) | class SamplingOptions:
  function parse_prompt (line 34) | def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
  function parse_img_cond_path (line 108) | def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOpti...
  function main (line 150) | def main(

FILE: src/flux/cli_redux.py
  class SamplingOptions (line 24) | class SamplingOptions:
  function parse_prompt (line 34) | def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
  function parse_img_cond_path (line 99) | def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOpti...
  function main (line 139) | def main(

FILE: src/flux/content_filters.py
  class PixtralContentFilter (line 59) | class PixtralContentFilter(torch.nn.Module):
    method __init__ (line 60) | def __init__(
    method yes_no_logit_processor (line 78) | def yes_no_logit_processor(
    method test_image (line 92) | def test_image(self, image: Image.Image | str | torch.Tensor) -> bool:
    method test_txt (line 144) | def test_txt(self, txt: str) -> bool:

FILE: src/flux/math.py
  function attention (line 6) | def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
  function rope (line 15) | def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
  function apply_rope (line 25) | def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tenso...

FILE: src/flux/model.py
  class FluxParams (line 18) | class FluxParams:
  class Flux (line 34) | class Flux(nn.Module):
    method __init__ (line 39) | def __init__(self, params: FluxParams):
    method forward (line 84) | def forward(
  class FluxLoraWrapper (line 122) | class FluxLoraWrapper(Flux):
    method __init__ (line 123) | def __init__(
    method set_lora_scale (line 140) | def set_lora_scale(self, scale: float) -> None:

FILE: src/flux/modules/autoencoder.py
  class AutoEncoderParams (line 9) | class AutoEncoderParams:
  function swish (line 21) | def swish(x: Tensor) -> Tensor:
  class AttnBlock (line 25) | class AttnBlock(nn.Module):
    method __init__ (line 26) | def __init__(self, in_channels: int):
    method attention (line 37) | def attention(self, h_: Tensor) -> Tensor:
    method forward (line 51) | def forward(self, x: Tensor) -> Tensor:
  class ResnetBlock (line 55) | class ResnetBlock(nn.Module):
    method __init__ (line 56) | def __init__(self, in_channels: int, out_channels: int):
    method forward (line 69) | def forward(self, x):
  class Downsample (line 85) | class Downsample(nn.Module):
    method __init__ (line 86) | def __init__(self, in_channels: int):
    method forward (line 91) | def forward(self, x: Tensor):
  class Upsample (line 98) | class Upsample(nn.Module):
    method __init__ (line 99) | def __init__(self, in_channels: int):
    method forward (line 103) | def forward(self, x: Tensor):
  class Encoder (line 109) | class Encoder(nn.Module):
    method __init__ (line 110) | def __init__(
    method forward (line 159) | def forward(self, x: Tensor) -> Tensor:
  class Decoder (line 183) | class Decoder(nn.Module):
    method __init__ (line 184) | def __init__(
    method forward (line 237) | def forward(self, z: Tensor) -> Tensor:
  class DiagonalGaussian (line 267) | class DiagonalGaussian(nn.Module):
    method __init__ (line 268) | def __init__(self, sample: bool = True, chunk_dim: int = 1):
    method forward (line 273) | def forward(self, z: Tensor) -> Tensor:
  class AutoEncoder (line 282) | class AutoEncoder(nn.Module):
    method __init__ (line 283) | def __init__(self, params: AutoEncoderParams, sample_z: bool = False):
    method encode (line 308) | def encode(self, x: Tensor) -> Tensor:
    method decode (line 313) | def decode(self, z: Tensor) -> Tensor:
    method forward (line 317) | def forward(self, x: Tensor) -> Tensor:

FILE: src/flux/modules/conditioner.py
  class HFEmbedder (line 5) | class HFEmbedder(nn.Module):
    method __init__ (line 6) | def __init__(self, version: str, max_length: int, **hf_kwargs):
    method forward (line 21) | def forward(self, text: list[str]) -> Tensor:

FILE: src/flux/modules/image_embedders.py
  class DepthImageEncoder (line 13) | class DepthImageEncoder:
    method __init__ (line 16) | def __init__(self, device):
    method __call__ (line 21) | def __call__(self, img: torch.Tensor) -> torch.Tensor:
  class CannyImageEncoder (line 36) | class CannyImageEncoder:
    method __init__ (line 37) | def __init__(
    method __call__ (line 47) | def __call__(self, img: torch.Tensor) -> torch.Tensor:
  class ReduxImageEncoder (line 64) | class ReduxImageEncoder(nn.Module):
    method __init__ (line 67) | def __init__(
    method __call__ (line 92) | def __call__(self, x: Image.Image) -> torch.Tensor:

FILE: src/flux/modules/layers.py
  class EmbedND (line 11) | class EmbedND(nn.Module):
    method __init__ (line 12) | def __init__(self, dim: int, theta: int, axes_dim: list[int]):
    method forward (line 18) | def forward(self, ids: Tensor) -> Tensor:
  function timestep_embedding (line 28) | def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: fl...
  class MLPEmbedder (line 52) | class MLPEmbedder(nn.Module):
    method __init__ (line 53) | def __init__(self, in_dim: int, hidden_dim: int):
    method forward (line 59) | def forward(self, x: Tensor) -> Tensor:
  class RMSNorm (line 63) | class RMSNorm(torch.nn.Module):
    method __init__ (line 64) | def __init__(self, dim: int):
    method forward (line 68) | def forward(self, x: Tensor):
  class QKNorm (line 75) | class QKNorm(torch.nn.Module):
    method __init__ (line 76) | def __init__(self, dim: int):
    method forward (line 81) | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Te...
  class SelfAttention (line 87) | class SelfAttention(nn.Module):
    method __init__ (line 88) | def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
    method forward (line 97) | def forward(self, x: Tensor, pe: Tensor) -> Tensor:
  class ModulationOut (line 107) | class ModulationOut:
  class Modulation (line 113) | class Modulation(nn.Module):
    method __init__ (line 114) | def __init__(self, dim: int, double: bool):
    method forward (line 120) | def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut |...
  class DoubleStreamBlock (line 129) | class DoubleStreamBlock(nn.Module):
    method __init__ (line 130) | def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float,...
    method forward (line 158) | def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -...
  class SingleStreamBlock (line 194) | class SingleStreamBlock(nn.Module):
    method __init__ (line 200) | def __init__(
    method forward (line 227) | def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
  class LastLayer (line 242) | class LastLayer(nn.Module):
    method __init__ (line 243) | def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
    method forward (line 249) | def forward(self, x: Tensor, vec: Tensor) -> Tensor:

FILE: src/flux/modules/lora.py
  function replace_linear_with_lora (line 5) | def replace_linear_with_lora(
  class LinearLora (line 34) | class LinearLora(nn.Linear):
    method __init__ (line 35) | def __init__(
    method set_scale (line 84) | def set_scale(self, scale: float) -> None:
    method forward (line 88) | def forward(self, input: torch.Tensor) -> torch.Tensor:

FILE: src/flux/sampling.py
  function get_noise (line 17) | def get_noise(
  function prepare (line 36) | def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str |...
  function prepare_control (line 70) | def prepare_control(
  function prepare_fill (line 107) | def prepare_fill(
  function prepare_redux (line 160) | def prepare_redux(
  function prepare_kontext (line 210) | def prepare_kontext(
  function time_shift (line 277) | def time_shift(mu: float, sigma: float, t: Tensor):
  function get_lin_function (line 281) | def get_lin_function(
  function get_schedule (line 289) | def get_schedule(
  function denoise (line 308) | def denoise(
  function unpack (line 356) | def unpack(x: Tensor, height: int, width: int) -> Tensor:

FILE: src/flux/trt/engine/base_engine.py
  class SharedMemory (line 32) | class SharedMemory(object):
    method __new__ (line 33) | def __new__(cls, *args, **kwargs):
    method __init__ (line 39) | def __init__(self, size: int, device=torch.device("cuda")):
    method resize (line 48) | def resize(self, name: str, size: int):
    method reset (line 54) | def reset(self, name: str):
    method deallocate (line 61) | def deallocate(self):
    method shared_device_memory (line 72) | def shared_device_memory(self):
    method __str__ (line 75) | def __str__(self):
  class BaseEngine (line 99) | class BaseEngine(ABC):
    method trt_datatype_to_torch (line 101) | def trt_datatype_to_torch(datatype):
    method cpu (line 118) | def cpu(self) -> "BaseEngine":
    method cuda (line 122) | def cuda(self) -> "BaseEngine":
    method to (line 126) | def to(self, device: str | torch.device) -> "BaseEngine":
  class Engine (line 130) | class Engine(BaseEngine):
    method __init__ (line 131) | def __init__(
    method __call__ (line 157) | def __call__(self, *args, **Kwargs) -> torch.Tensor | dict[str, torch....
    method cpu (line 160) | def cpu(self) -> "Engine":
    method cuda (line 171) | def cuda(self) -> "Engine":
    method to (line 182) | def to(self, device: str | torch.device) -> "Engine":
    method deactivate (line 194) | def deactivate(self):
    method allocate_buffers (line 198) | def allocate_buffers(
    method get_dtype (line 220) | def get_dtype(self, tensor_name: str):
    method override_shapes (line 223) | def override_shapes(self, feed_dict: Dict[str, torch.Tensor]):
    method deallocate_buffers (line 249) | def deallocate_buffers(self):
    method device_memory_size (line 258) | def device_memory_size(self):
    method calculate_input_hash (line 267) | def calculate_input_hash(feed_dict: Dict[str, torch.Tensor]):
    method _capture_cuda_graph (line 270) | def _capture_cuda_graph(self):
    method infer (line 280) | def infer(
    method __str__ (line 297) | def __str__(self):

FILE: src/flux/trt/engine/clip_engine.py
  class CLIPEngine (line 24) | class CLIPEngine(Engine):
    method __init__ (line 25) | def __init__(self, trt_config: ClipConfig, stream: torch.cuda.Stream, ...
    method __call__ (line 33) | def __call__(

FILE: src/flux/trt/engine/t5_engine.py
  class T5Engine (line 24) | class T5Engine(Engine):
    method __init__ (line 25) | def __init__(self, trt_config: T5Config, stream: torch.cuda.Stream, **...
    method __call__ (line 33) | def __call__(

FILE: src/flux/trt/engine/transformer_engine.py
  class TransformerEngine (line 23) | class TransformerEngine(Engine):
    method __init__ (line 46) | def __init__(self, trt_config: TransformerConfig, stream: torch.cuda.S...
    method dd_to_flux (line 50) | def dd_to_flux(self):
    method flux_to_dd (line 54) | def flux_to_dd(self):
    method __call__ (line 58) | def __call__(

FILE: src/flux/trt/engine/vae_engine.py
  class VAEDecoder (line 23) | class VAEDecoder(Engine):
    method __init__ (line 24) | def __init__(self, trt_config: VAEDecoderConfig, stream: torch.cuda.St...
    method __call__ (line 28) | def __call__(
  class VAEEncoder (line 39) | class VAEEncoder(Engine):
    method __init__ (line 40) | def __init__(self, trt_config: VAEEncoderConfig, stream: torch.cuda.St...
    method __call__ (line 44) | def __call__(
  class VAEEngine (line 54) | class VAEEngine(BaseEngine):
    method __init__ (line 55) | def __init__(
    method decode (line 64) | def decode(self, z: torch.Tensor) -> torch.Tensor:
    method encode (line 67) | def encode(self, x: torch.Tensor) -> torch.Tensor:
    method cpu (line 71) | def cpu(self):
    method cuda (line 77) | def cuda(self):
    method to (line 83) | def to(self, device):
    method device_memory_size (line 90) | def device_memory_size(self):

FILE: src/flux/trt/trt_config/base_trt_config.py
  class ModuleName (line 30) | class ModuleName(Enum):
  class TRTBaseConfig (line 42) | class TRTBaseConfig:
    method build_trt_engine (line 69) | def build_trt_engine(
    method from_args (line 174) | def from_args(cls, model_name: str, *args, **kwargs) -> Any:
    method get_input_profile (line 178) | def get_input_profile(
    method check_dims (line 209) | def check_dims(self, *args, **kwargs) -> None | tuple[int, int] | int:
    method _check_batch (line 213) | def _check_batch(self, batch_size):
    method __post_init__ (line 218) | def __post_init__(self):
    method _get_onnx_path (line 223) | def _get_onnx_path(self) -> str:
    method _get_engine_path (line 232) | def _get_engine_path(self) -> str:
    method _get_repo_id (line 240) | def _get_repo_id(model_name: str) -> str:
  function register_config (line 255) | def register_config(module_name: ModuleName, precision: str):
  function get_config (line 266) | def get_config(module_name: ModuleName, precision: str) -> TRTBaseConfig:

FILE: src/flux/trt/trt_config/clip_trt_config.py
  class ClipConfig (line 25) | class ClipConfig(TRTBaseConfig):
    method from_args (line 35) | def from_args(
    method check_dims (line 48) | def check_dims(self, batch_size: int) -> None:
    method get_input_profile (line 51) | def get_input_profile(

FILE: src/flux/trt/trt_config/t5_trt_config.py
  class T5Config (line 29) | class T5Config(TRTBaseConfig):
    method from_args (line 40) | def from_args(
    method check_dims (line 53) | def check_dims(self, batch_size: int) -> None:
    method get_input_profile (line 56) | def get_input_profile(
    method _get_onnx_path (line 74) | def _get_onnx_path(self) -> str:

FILE: src/flux/trt/trt_config/transformer_trt_config.py
  class TransformerConfig (line 31) | class TransformerConfig(TRTBaseConfig):
    method from_args (line 57) | def from_args(
    method _get_onnx_path (line 87) | def _get_onnx_path(self) -> str:
    method _get_latent (line 99) | def _get_latent(image_dim: int, compression_factor: int) -> int:
    method _get_context_dim (line 103) | def _get_context_dim(
    method __post_init__ (line 118) | def __post_init__(self):
    method get_minmax_dims (line 160) | def get_minmax_dims(
    method check_dims (line 194) | def check_dims(
    method get_input_profile (line 219) | def get_input_profile(

FILE: src/flux/trt/trt_config/vae_trt_config.py
  class VAEBaseConfig (line 25) | class VAEBaseConfig(TRTBaseConfig):
    method _get_latent_dim (line 38) | def _get_latent_dim(self, image_dim: int) -> int:
    method __post_init__ (line 41) | def __post_init__(self):
    method check_dims (line 46) | def check_dims(
  class VAEDecoderConfig (line 71) | class VAEDecoderConfig(VAEBaseConfig):
    method from_args (line 79) | def from_args(
    method get_minmax_dims (line 102) | def get_minmax_dims(
    method get_input_profile (line 128) | def get_input_profile(
  class VAEEncoderConfig (line 175) | class VAEEncoderConfig(VAEBaseConfig):
    method from_args (line 183) | def from_args(cls, model_name: str, **kwargs):
    method get_minmax_dims (line 206) | def get_minmax_dims(
    method get_input_profile (line 229) | def get_input_profile(

FILE: src/flux/trt/trt_manager.py
  class TRTManager (line 46) | class TRTManager:
    method module_to_engine_class (line 48) | def module_to_engine_class(self) -> dict[ModuleName, type[Engine]]:
    method __init__ (line 57) | def __init__(
    method _parse_models_precisions (line 76) | def _parse_models_precisions(
    method _parse_custom_onnx_path (line 99) | def _parse_custom_onnx_path(custom_onnx_paths: str) -> dict[ModuleName...
    method _create_directories (line 128) | def _create_directories(engine_dir: str):
    method _get_trt_configs (line 132) | def _get_trt_configs(
    method _build_engine (line 179) | def _build_engine(
    method load_engines (line 211) | def load_engines(
    method _clean_memory (line 279) | def _clean_memory():
    method init_runtime (line 283) | def init_runtime(self):
    method stop_runtime (line 290) | def stop_runtime(self):

FILE: src/flux/util.py
  function ensure_hf_auth (line 27) | def ensure_hf_auth():
  function prompt_for_hf_auth (line 45) | def prompt_for_hf_auth():
  function get_checkpoint_path (line 64) | def get_checkpoint_path(repo_id: str, filename: str, env_var: str) -> Path:
  function download_onnx_models_for_trt (line 108) | def download_onnx_models_for_trt(model_name: str, trt_transformer_precis...
  function check_onnx_access_for_trt (line 201) | def check_onnx_access_for_trt(model_name: str, trt_transformer_precision...
  function track_usage_via_api (line 206) | def track_usage_via_api(name: str, n=1) -> None:
  function save_image (line 243) | def save_image(
  class ModelSpec (line 288) | class ModelSpec:
  function aspect_ratio_to_height_width (line 637) | def aspect_ratio_to_height_width(aspect_ratio: str, area: int = 1024**2)...
  function print_load_warning (line 646) | def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
  function load_flow_model (line 657) | def load_flow_model(name: str, device: str | torch.device = "cuda", verb...
  function load_t5 (line 689) | def load_t5(device: str | torch.device = "cuda", max_length: int = 512) ...
  function load_clip (line 694) | def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
  function load_ae (line 698) | def load_ae(name: str, device: str | torch.device = "cuda") -> AutoEncoder:
  function optionally_expand_state_dict (line 714) | def optionally_expand_state_dict(model: torch.nn.Module, state_dict: dic...
  class WatermarkEmbedder (line 733) | class WatermarkEmbedder:
    method __init__ (line 734) | def __init__(self, watermark):
    method __call__ (line 740) | def __call__(self, image: torch.Tensor) -> torch.Tensor:
Condensed preview: 51 files, each showing path, character count, and a content snippet (full structured content: 342K chars).
[
  {
    "path": ".github/workflows/ci.yaml",
    "chars": 545,
    "preview": "name: CI\non: push\njobs:\n  lint:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v2\n      - uses: ac"
  },
  {
    "path": ".gitignore",
    "chars": 3706,
    "preview": "# Created by https://www.toptal.com/developers/gitignore/api/linux,windows,macos,visualstudiocode,python\n# Edit at https"
  },
  {
    "path": "LICENSE",
    "chars": 11357,
    "preview": "                                 Apache License\n                           Version 2.0, January 2004\n                   "
  },
  {
    "path": "README.md",
    "chars": 6721,
    "preview": "# FLUX\nby Black Forest Labs: https://bfl.ai.\n\nDocumentation for our API can be found here: [docs.bfl.ai](https://docs.bf"
  },
  {
    "path": "demo_gr.py",
    "chars": 9550,
    "preview": "import os\nimport time\nimport uuid\n\nimport gradio as gr\nimport numpy as np\nimport torch\nfrom einops import rearrange\nfrom"
  },
  {
    "path": "demo_st.py",
    "chars": 9874,
    "preview": "import os\nimport re\nimport time\nfrom glob import iglob\nfrom io import BytesIO\n\nimport streamlit as st\nimport torch\nfrom "
  },
  {
    "path": "demo_st_fill.py",
    "chars": 17944,
    "preview": "import os\nimport re\nimport tempfile\nimport time\nfrom glob import iglob\nfrom io import BytesIO\n\nimport numpy as np\nimport"
  },
  {
    "path": "docs/fill.md",
    "chars": 1882,
    "preview": "## Open-weight models\n\nFLUX.1 Fill introduces advanced inpainting and outpainting capabilities. It allows for seamless e"
  },
  {
    "path": "docs/image-editing.md",
    "chars": 2027,
    "preview": "## Open-weight models\n\nWe currently offer two open-weight text-to-image models.\n\n| Name                      | HuggingFa"
  },
  {
    "path": "docs/image-variation.md",
    "chars": 1532,
    "preview": "## Models\n\nFLUX.1 Redux is an adapter for the FLUX.1 text-to-image base models, FLUX.1 [dev] and FLUX.1 [schnell], which"
  },
  {
    "path": "docs/structural-conditioning.md",
    "chars": 3506,
    "preview": "## Models\n\nStructural conditioning uses canny edge or depth detection to maintain precise control during image transform"
  },
  {
    "path": "docs/text-to-image.md",
    "chars": 4150,
    "preview": "## Open-weight models\n\nWe currently offer two open-weight text-to-image models.\n\n| Name                      | HuggingFa"
  },
  {
    "path": "model_cards/FLUX.1-Krea-dev.md",
    "chars": 6477,
    "preview": "![FLUX.1 Krea [dev] Grid](../assets/flux-1-krea-dev-grid.png)\n\n`FLUX.1 Krea [dev]` is a 12 billion parameter rectified f"
  },
  {
    "path": "model_cards/FLUX.1-dev.md",
    "chars": 2933,
    "preview": "![FLUX.1 [dev] Grid](../assets/dev_grid.jpg)\n\n`FLUX.1 [dev]` is a 12 billion parameter rectified flow transformer capabl"
  },
  {
    "path": "model_cards/FLUX.1-kontext-dev.md",
    "chars": 7610,
    "preview": "![FLUX.1 [dev] Grid](../assets/docs/kontext.png)\n\n`FLUX.1 Kontext [dev]` is a 12 billion parameter rectified flow transf"
  },
  {
    "path": "model_cards/FLUX.1-schnell.md",
    "chars": 2665,
    "preview": "![FLUX.1 [schnell] Grid](../assets/schnell_grid.jpg)\n\n`FLUX.1 [schnell]` is a 12 billion parameter rectified flow transf"
  },
  {
    "path": "model_licenses/LICENSE-FLUX1-dev",
    "chars": 18491,
    "preview": "FLUX.1 [dev] Non-Commercial License v1.1.1\n\nBlack Forest Labs Inc. (“we” or “our” or “Company”) is pleased to make avail"
  },
  {
    "path": "model_licenses/LICENSE-FLUX1-schnell",
    "chars": 9155,
    "preview": "\n\nApache License\nVersion 2.0, January 2004\nhttp://www.apache.org/licenses/\n\nTERMS AND CONDITIONS FOR USE, REPRODUCTION, "
  },
  {
    "path": "pyproject.toml",
    "chars": 2239,
    "preview": "[project]\nname = \"flux\"\nauthors = [\n  { name = \"Black Forest Labs\", email = \"support@blackforestlabs.ai\" },\n]\ndescriptio"
  },
  {
    "path": "setup.py",
    "chars": 38,
    "preview": "import setuptools\n\nsetuptools.setup()\n"
  },
  {
    "path": "src/flux/__init__.py",
    "chars": 345,
    "preview": "try:\n    from ._version import (\n        version as __version__,  # type: ignore\n        version_tuple,\n    )\nexcept Imp"
  },
  {
    "path": "src/flux/__main__.py",
    "chars": 462,
    "preview": "from fire import Fire\n\nfrom .cli import main as cli_main\nfrom .cli_control import main as control_main\nfrom .cli_fill im"
  },
  {
    "path": "src/flux/cli.py",
    "chars": 10486,
    "preview": "import os\nimport re\nimport time\nfrom dataclasses import dataclass\nfrom glob import iglob\n\nimport torch\nfrom fire import "
  },
  {
    "path": "src/flux/cli_control.py",
    "chars": 13733,
    "preview": "import os\nimport re\nimport time\nfrom dataclasses import dataclass\nfrom glob import iglob\n\nimport torch\nfrom fire import "
  },
  {
    "path": "src/flux/cli_fill.py",
    "chars": 11209,
    "preview": "import os\nimport re\nimport time\nfrom dataclasses import dataclass\nfrom glob import iglob\n\nimport torch\nfrom fire import "
  },
  {
    "path": "src/flux/cli_kontext.py",
    "chars": 13165,
    "preview": "import os\nimport re\nimport time\nfrom dataclasses import dataclass\nfrom glob import iglob\n\nimport torch\nfrom fire import "
  },
  {
    "path": "src/flux/cli_redux.py",
    "chars": 9767,
    "preview": "import os\nimport re\nimport time\nfrom dataclasses import dataclass\nfrom glob import iglob\n\nimport torch\nfrom fire import "
  },
  {
    "path": "src/flux/content_filters.py",
    "chars": 5959,
    "preview": "import torch\nfrom einops import rearrange\nfrom PIL import Image\nfrom transformers import AutoProcessor, LlavaForConditio"
  },
  {
    "path": "src/flux/math.py",
    "chars": 1166,
    "preview": "import torch\nfrom einops import rearrange\nfrom torch import Tensor\n\n\ndef attention(q: Tensor, k: Tensor, v: Tensor, pe: "
  },
  {
    "path": "src/flux/model.py",
    "chars": 4399,
    "preview": "from dataclasses import dataclass\n\nimport torch\nfrom torch import Tensor, nn\n\nfrom flux.modules.layers import (\n    Doub"
  },
  {
    "path": "src/flux/modules/autoencoder.py",
    "chars": 10652,
    "preview": "from dataclasses import dataclass\n\nimport torch\nfrom einops import rearrange\nfrom torch import Tensor, nn\n\n\n@dataclass\nc"
  },
  {
    "path": "src/flux/modules/conditioner.py",
    "chars": 1502,
    "preview": "from torch import Tensor, nn\nfrom transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer\n\n\nclass "
  },
  {
    "path": "src/flux/modules/image_embedders.py",
    "chars": 3465,
    "preview": "import cv2\nimport numpy as np\nimport torch\nfrom einops import rearrange, repeat\nfrom PIL import Image\nfrom safetensors.t"
  },
  {
    "path": "src/flux/modules/layers.py",
    "chars": 9374,
    "preview": "import math\nfrom dataclasses import dataclass\n\nimport torch\nfrom einops import rearrange\nfrom torch import Tensor, nn\n\nf"
  },
  {
    "path": "src/flux/modules/lora.py",
    "chars": 2545,
    "preview": "import torch\nfrom torch import nn\n\n\ndef replace_linear_with_lora(\n    module: nn.Module,\n    max_rank: int,\n    scale: f"
  },
  {
    "path": "src/flux/sampling.py",
    "chars": 11434,
    "preview": "import math\nfrom typing import Callable\n\nimport numpy as np\nimport torch\nfrom einops import rearrange, repeat\nfrom PIL i"
  },
  {
    "path": "src/flux/trt/__init__.py",
    "chars": 127,
    "preview": "from flux.trt.trt_config import ModuleName\nfrom flux.trt.trt_manager import TRTManager\n\n__all__ = [\"TRTManager\", \"Module"
  },
  {
    "path": "src/flux/trt/engine/__init__.py",
    "chars": 1175,
    "preview": "#\n# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License"
  },
  {
    "path": "src/flux/trt/engine/base_engine.py",
    "chars": 11083,
    "preview": "#\n# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License"
  },
  {
    "path": "src/flux/trt/engine/clip_engine.py",
    "chars": 1853,
    "preview": "#\n# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License"
  },
  {
    "path": "src/flux/trt/engine/t5_engine.py",
    "chars": 1826,
    "preview": "#\n# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License"
  },
  {
    "path": "src/flux/trt/engine/transformer_engine.py",
    "chars": 2491,
    "preview": "#\n# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License"
  },
  {
    "path": "src/flux/trt/engine/vae_engine.py",
    "chars": 3163,
    "preview": "#\n# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License"
  },
  {
    "path": "src/flux/trt/trt_config/__init__.py",
    "chars": 1261,
    "preview": "#\n# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License"
  },
  {
    "path": "src/flux/trt/trt_config/base_trt_config.py",
    "chars": 11052,
    "preview": "#\n# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License"
  },
  {
    "path": "src/flux/trt/trt_config/clip_trt_config.py",
    "chars": 2105,
    "preview": "#\n# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License"
  },
  {
    "path": "src/flux/trt/trt_config/t5_trt_config.py",
    "chars": 2736,
    "preview": "#\n# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License"
  },
  {
    "path": "src/flux/trt/trt_config/transformer_trt_config.py",
    "chars": 10336,
    "preview": "#\n# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License"
  },
  {
    "path": "src/flux/trt/trt_config/vae_trt_config.py",
    "chars": 9481,
    "preview": "#\n# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License"
  },
  {
    "path": "src/flux/trt/trt_manager.py",
    "chars": 10243,
    "preview": "#\n# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License"
  },
  {
    "path": "src/flux/util.py",
    "chars": 26452,
    "preview": "import getpass\nimport math\nimport os\nfrom dataclasses import dataclass\nfrom pathlib import Path\n\nimport requests\nimport "
  }
]
