Repository: black-forest-labs/flux
Branch: main
Commit: 802fb4713906
Files: 51
Total size: 319.8 KB
Directory structure:
gitextract_5wlz5y8p/
├── .github/
│   └── workflows/
│       └── ci.yaml
├── .gitignore
├── LICENSE
├── README.md
├── demo_gr.py
├── demo_st.py
├── demo_st_fill.py
├── docs/
│   ├── fill.md
│   ├── image-editing.md
│   ├── image-variation.md
│   ├── structural-conditioning.md
│   └── text-to-image.md
├── model_cards/
│   ├── FLUX.1-Krea-dev.md
│   ├── FLUX.1-dev.md
│   ├── FLUX.1-kontext-dev.md
│   └── FLUX.1-schnell.md
├── model_licenses/
│   ├── LICENSE-FLUX1-dev
│   └── LICENSE-FLUX1-schnell
├── pyproject.toml
├── setup.py
└── src/
    └── flux/
        ├── __init__.py
        ├── __main__.py
        ├── cli.py
        ├── cli_control.py
        ├── cli_fill.py
        ├── cli_kontext.py
        ├── cli_redux.py
        ├── content_filters.py
        ├── math.py
        ├── model.py
        ├── modules/
        │   ├── autoencoder.py
        │   ├── conditioner.py
        │   ├── image_embedders.py
        │   ├── layers.py
        │   └── lora.py
        ├── sampling.py
        ├── trt/
        │   ├── __init__.py
        │   ├── engine/
        │   │   ├── __init__.py
        │   │   ├── base_engine.py
        │   │   ├── clip_engine.py
        │   │   ├── t5_engine.py
        │   │   ├── transformer_engine.py
        │   │   └── vae_engine.py
        │   ├── trt_config/
        │   │   ├── __init__.py
        │   │   ├── base_trt_config.py
        │   │   ├── clip_trt_config.py
        │   │   ├── t5_trt_config.py
        │   │   ├── transformer_trt_config.py
        │   │   └── vae_trt_config.py
        │   └── trt_manager.py
        └── util.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/workflows/ci.yaml
================================================
name: CI
on: push
jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - uses: actions/setup-python@v2
        with:
          python-version: "3.10"
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install ruff==0.6.8
      - name: Run Ruff
        run: ruff check --output-format=github .
      - name: Check imports
        run: ruff check --select I --output-format=github .
      - name: Check formatting
        run: ruff format --check .
================================================
FILE: .gitignore
================================================
# Created by https://www.toptal.com/developers/gitignore/api/linux,windows,macos,visualstudiocode,python
# Edit at https://www.toptal.com/developers/gitignore?templates=linux,windows,macos,visualstudiocode,python
### Linux ###
*~
# temporary files which can be created if a process still has a handle open of a deleted file
.fuse_hidden*
# KDE directory preferences
.directory
# Linux trash folder which might appear on any partition or disk
.Trash-*
# .nfs files are created when an open file is removed but is still being accessed
.nfs*
### macOS ###
# General
.DS_Store
.AppleDouble
.LSOverride
# Icon must end with two \r
Icon
# Thumbnails
._*
# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent
# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk
### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
### VisualStudioCode ###
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
*.code-workspace
# Local History for Visual Studio Code
.history/
### VisualStudioCode Patch ###
# Ignore all local history of files
.history
.ionide
### Windows ###
# Windows thumbnail cache files
Thumbs.db
Thumbs.db:encryptable
ehthumbs.db
ehthumbs_vista.db
# Dump file
*.stackdump
# Folder config file
[Dd]esktop.ini
# Recycle Bin used on file shares
$RECYCLE.BIN/
# Windows Installer files
*.cab
*.msi
*.msix
*.msm
*.msp
# Windows shortcuts
*.lnk
# End of https://www.toptal.com/developers/gitignore/api/linux,windows,macos,visualstudiocode,python
output/
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# FLUX
by Black Forest Labs: https://bfl.ai.
Documentation for our API can be found here: [docs.bfl.ai](https://docs.bfl.ai/).

This repo contains minimal inference code to run image generation & editing with our Flux open-weight models.
## Local installation
```bash
cd $HOME && git clone https://github.com/black-forest-labs/flux
cd $HOME/flux
python3.10 -m venv .venv
source .venv/bin/activate
pip install -e ".[all]"
```
### Local installation with TensorRT support
If you would like to install the repository with [TensorRT](https://github.com/NVIDIA/TensorRT) support, you currently need to use a PyTorch container image from NVIDIA instead. First install [enroot](https://github.com/NVIDIA/enroot), then follow the steps below:
```bash
cd $HOME && git clone https://github.com/black-forest-labs/flux
enroot import 'docker://$oauthtoken@nvcr.io#nvidia/pytorch:25.01-py3'
enroot create -n pti2501 nvidia+pytorch+25.01-py3.sqsh
enroot start --rw -m ${PWD}/flux:/workspace/flux -r pti2501
cd flux
pip install -e ".[tensorrt]" --extra-index-url https://pypi.nvidia.com
```
### Open-weight models
We offer an extensive suite of open-weight models. For more information about the individual models, please refer to the link under **Usage**.

| Name | Usage | HuggingFace repo | License |
| --------------------------- | ---------------------------------------------------------- | -------------------------------------------------------------- | --------------------------------------------------------------------- |
| `FLUX.1 [schnell]` | [Text to Image](docs/text-to-image.md) | https://huggingface.co/black-forest-labs/FLUX.1-schnell | [apache-2.0](model_licenses/LICENSE-FLUX1-schnell) |
| `FLUX.1 [dev]` | [Text to Image](docs/text-to-image.md) | https://huggingface.co/black-forest-labs/FLUX.1-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) |
| `FLUX.1 Fill [dev]` | [In/Out-painting](docs/fill.md) | https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) |
| `FLUX.1 Canny [dev]` | [Structural Conditioning](docs/structural-conditioning.md) | https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) |
| `FLUX.1 Depth [dev]` | [Structural Conditioning](docs/structural-conditioning.md) | https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) |
| `FLUX.1 Canny [dev] LoRA` | [Structural Conditioning](docs/structural-conditioning.md) | https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) |
| `FLUX.1 Depth [dev] LoRA` | [Structural Conditioning](docs/structural-conditioning.md) | https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev-lora | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) |
| `FLUX.1 Redux [dev]` | [Image variation](docs/image-variation.md) | https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) |
| `FLUX.1 Kontext [dev]` | [Image editing](docs/image-editing.md) | https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) |
| `FLUX.1 Krea [dev]` | [Text to Image](docs/text-to-image.md) | https://huggingface.co/black-forest-labs/FLUX.1-Krea-dev | [FLUX.1-dev Non-Commercial License](model_licenses/LICENSE-FLUX1-dev) |

The weights of the autoencoder are also released under [apache-2.0](https://huggingface.co/datasets/choosealicense/licenses/blob/main/markdown/apache-2.0.md) and can be found in the HuggingFace repos above.
## API usage
Our API offers access to all models, including our Pro-tier models, whose weights are not open. Check out our API documentation at [docs.bfl.ai](https://docs.bfl.ai/) to learn more.
## Licensing models for commercial use
You can license our models for commercial use here: https://bfl.ai/pricing/licensing
Because the fee is based on monthly usage, we provide code that automatically tracks your usage via the BFL API. To enable usage tracking, pass `--track_usage` to the CLI or tick the corresponding checkbox in the provided demos.
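In `demo_gr.py` in this repository, usage is reported only after a generated image has passed the NSFW check and been saved, with one call to `track_usage_via_api(self.model_name, 1)` per delivered image. The sketch below mirrors that gating logic under stated assumptions: `deliver_and_report` and the `report` callback are hypothetical names, and the callback is a stub standing in for the real API call so the example runs without network access or an API key.

```python
from typing import Callable

# Hypothetical callback signature: (model_name, num_images) -> None.
# In demo_gr.py the real call is track_usage_via_api(self.model_name, 1); a stub is used here.
Reporter = Callable[[str, int], None]


def deliver_and_report(
    model_name: str,
    image_is_safe: bool,
    track_usage: bool,
    report: Reporter,
) -> bool:
    """Return True if the image is delivered, reporting one generation when tracking is on.

    Hypothetical helper mirroring the gating in demo_gr.py: filtered images are
    neither returned nor reported.
    """
    if not image_is_safe:
        # The image is withheld, so nothing is reported.
        return False
    if track_usage:
        # Exactly one reported generation per delivered image.
        report(model_name, 1)
    return True


if __name__ == "__main__":
    reported: list[tuple[str, int]] = []

    def stub(model_name: str, num_images: int) -> None:
        reported.append((model_name, num_images))

    deliver_and_report("flux-dev", image_is_safe=True, track_usage=True, report=stub)
    deliver_and_report("flux-dev", image_is_safe=False, track_usage=True, report=stub)
    deliver_and_report("flux-dev", image_is_safe=True, track_usage=False, report=stub)
    print(reported)  # [('flux-dev', 1)]
```

Passing the reporter in as a parameter keeps the counting rule testable without credentials; a real integration would pass the library's own reporting function instead of the stub.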
### Example: Using FLUX.1 Kontext with usage tracking
We provide a reference implementation for running FLUX.1 with usage tracking enabled for commercial licensing.
It can be customized as needed, as long as the usage reporting remains accurate.
For the reporting logic to work, set your API key as an environment variable before running:
```bash
export BFL_API_KEY="your_api_key_here"
```
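Because the reporting logic depends on `BFL_API_KEY` being set, a launcher script can check for it up front instead of failing later in a session. The following is a minimal sketch: `require_bfl_api_key` is a hypothetical helper, not part of this repository, and it only checks that the variable is non-empty; it does not validate the key against the API.

```python
import os


def require_bfl_api_key(env_var: str = "BFL_API_KEY") -> str:
    """Return the API key from the environment or raise a clear error.

    Hypothetical pre-flight check to run before starting a session with
    --track_usage; it does not contact the API or validate the key.
    """
    key = os.environ.get(env_var, "").strip()
    if not key:
        raise RuntimeError(f"{env_var} is not set; export it before running with --track_usage")
    return key


if __name__ == "__main__":
    try:
        require_bfl_api_key()
        print("BFL_API_KEY found; usage reporting can proceed.")
    except RuntimeError as err:
        print(err)
```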
You can run `FLUX.1 Kontext [dev]` in a loop with tracking activated like this:
```bash
python -m flux kontext --track_usage --loop
```
For a single generation:
```bash
python -m flux kontext --track_usage --prompt "replace the logo with the text 'Black Forest Labs'"
```
The above reporting logic works similarly for FLUX.1 [dev] and FLUX.1 Tools [dev].
**Note that this is only required when using one or more of our open weights models commercially. More information on the commercial licensing can be found at the [BFL Helpdesk](https://help.bfl.ai/collections/6939000511-licensing).**
## Citation
If you find the provided code or models useful for your research, consider citing them as:
```bib
@misc{labs2025flux1kontextflowmatching,
title={FLUX.1 Kontext: Flow Matching for In-Context Image Generation and Editing in Latent Space},
author={Black Forest Labs and Stephen Batifol and Andreas Blattmann and Frederic Boesel and Saksham Consul and Cyril Diagne and Tim Dockhorn and Jack English and Zion English and Patrick Esser and Sumith Kulal and Kyle Lacey and Yam Levi and Cheng Li and Dominik Lorenz and Jonas Müller and Dustin Podell and Robin Rombach and Harry Saini and Axel Sauer and Luke Smith},
year={2025},
eprint={2506.15742},
archivePrefix={arXiv},
primaryClass={cs.GR},
url={https://arxiv.org/abs/2506.15742},
}
@misc{flux2024,
author={Black Forest Labs},
title={FLUX},
year={2024},
howpublished={\url{https://github.com/black-forest-labs/flux}},
}
```
================================================
FILE: demo_gr.py
================================================
import os
import time
import uuid
import gradio as gr
import numpy as np
import torch
from einops import rearrange
from PIL import ExifTags, Image
from transformers import pipeline
from flux.cli import SamplingOptions
from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from flux.util import (
    configs,
    embed_watermark,
    load_ae,
    load_clip,
    load_flow_model,
    load_t5,
    track_usage_via_api,
)
NSFW_THRESHOLD = 0.85


def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool):
    t5 = load_t5(device, max_length=256 if is_schnell else 512)
    clip = load_clip(device)
    model = load_flow_model(name, device="cpu" if offload else device)
    ae = load_ae(name, device="cpu" if offload else device)
    nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
    return model, ae, t5, clip, nsfw_classifier


class FluxGenerator:
    def __init__(self, model_name: str, device: str, offload: bool, track_usage: bool):
        self.device = torch.device(device)
        self.offload = offload
        self.model_name = model_name
        self.is_schnell = model_name == "flux-schnell"
        self.model, self.ae, self.t5, self.clip, self.nsfw_classifier = get_models(
            model_name,
            device=self.device,
            offload=self.offload,
            is_schnell=self.is_schnell,
        )
        self.track_usage = track_usage

    @torch.inference_mode()
    def generate_image(
        self,
        width,
        height,
        num_steps,
        guidance,
        seed,
        prompt,
        init_image=None,
        image2image_strength=0.0,
        add_sampling_metadata=True,
    ):
        seed = int(seed)
        if seed == -1:
            seed = None

        opts = SamplingOptions(
            prompt=prompt,
            width=width,
            height=height,
            num_steps=num_steps,
            guidance=guidance,
            seed=seed,
        )

        if opts.seed is None:
            opts.seed = torch.Generator(device="cpu").seed()
        print(f"Generating '{opts.prompt}' with seed {opts.seed}")
        t0 = time.perf_counter()

        if init_image is not None:
            if isinstance(init_image, np.ndarray):
                # Gradio yields an HWC uint8 array; convert to NCHW float in [-1, 1],
                # the same range demo_st.py feeds to the autoencoder
                init_image = torch.from_numpy(init_image).permute(2, 0, 1).float() / 127.5 - 1.0
                init_image = init_image.unsqueeze(0)
            init_image = init_image.to(self.device)
            init_image = torch.nn.functional.interpolate(init_image, (opts.height, opts.width))
            if self.offload:
                self.ae.encoder.to(self.device)
            init_image = self.ae.encode(init_image)
            if self.offload:
                self.ae = self.ae.cpu()
                torch.cuda.empty_cache()

        # prepare input
        x = get_noise(
            1,
            opts.height,
            opts.width,
            device=self.device,
            dtype=torch.bfloat16,
            seed=opts.seed,
        )
        timesteps = get_schedule(
            opts.num_steps,
            # image sequence length: one token per 2x2 patch of the latent
            x.shape[-1] * x.shape[-2] // 4,
            shift=(not self.is_schnell),
        )
        if init_image is not None:
            # img2img: skip the first (1 - strength) fraction of the schedule and start
            # from a noise/latent mix at timestep t (t = 1 is pure noise)
            t_idx = int((1 - image2image_strength) * num_steps)
            t = timesteps[t_idx]
            timesteps = timesteps[t_idx:]
            x = t * x + (1.0 - t) * init_image.to(x.dtype)

        if self.offload:
            self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device)
        inp = prepare(t5=self.t5, clip=self.clip, img=x, prompt=opts.prompt)

        # offload TEs to CPU, load model to gpu
        if self.offload:
            self.t5, self.clip = self.t5.cpu(), self.clip.cpu()
            torch.cuda.empty_cache()
            self.model = self.model.to(self.device)

        # denoise initial noise
        x = denoise(self.model, **inp, timesteps=timesteps, guidance=opts.guidance)

        # offload model, load autoencoder to gpu
        if self.offload:
            self.model.cpu()
            torch.cuda.empty_cache()
            self.ae.decoder.to(x.device)

        # decode latents to pixel space
        x = unpack(x.float(), opts.height, opts.width)
        with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16):
            x = self.ae.decode(x)

        if self.offload:
            self.ae.decoder.cpu()
            torch.cuda.empty_cache()

        t1 = time.perf_counter()
        print(f"Done in {t1 - t0:.1f}s.")

        # bring into PIL format
        x = x.clamp(-1, 1)
        x = embed_watermark(x.float())
        x = rearrange(x[0], "c h w -> h w c")
        img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())

        nsfw_score = [x["score"] for x in self.nsfw_classifier(img) if x["label"] == "nsfw"][0]

        if nsfw_score < NSFW_THRESHOLD:
            filename = f"output/gradio/{uuid.uuid4()}.jpg"
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            exif_data = Image.Exif()
            if init_image is None:
                exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
            else:
                exif_data[ExifTags.Base.Software] = "AI generated;img2img;flux"
            exif_data[ExifTags.Base.Make] = "Black Forest Labs"
            exif_data[ExifTags.Base.Model] = self.model_name
            if add_sampling_metadata:
                exif_data[ExifTags.Base.ImageDescription] = prompt
            img.save(filename, format="jpeg", exif=exif_data, quality=95, subsampling=0)

            if self.track_usage:
                track_usage_via_api(self.model_name, 1)

            return img, str(opts.seed), filename, None
        else:
            return None, str(opts.seed), None, "Your generated image may contain NSFW content."


def create_demo(
    model_name: str,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    offload: bool = False,
    track_usage: bool = False,
):
    generator = FluxGenerator(model_name, device, offload, track_usage)
    is_schnell = model_name == "flux-schnell"

    with gr.Blocks() as demo:
        gr.Markdown(f"# Flux Image Generation Demo - Model: {model_name}")

        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(
                    label="Prompt",
                    value='a photo of a forest with mist swirling around the tree trunks. The word "FLUX" is painted over it in big, red brush strokes with visible texture',
                )
                do_img2img = gr.Checkbox(label="Image to Image", value=False, interactive=not is_schnell)
                init_image = gr.Image(label="Input Image", visible=False)
                image2image_strength = gr.Slider(
                    0.0, 1.0, 0.8, step=0.1, label="Noising strength", visible=False
                )

                with gr.Accordion("Advanced Options", open=False):
                    width = gr.Slider(128, 8192, 1360, step=16, label="Width")
                    height = gr.Slider(128, 8192, 768, step=16, label="Height")
                    num_steps = gr.Slider(1, 50, 4 if is_schnell else 50, step=1, label="Number of steps")
                    guidance = gr.Slider(
                        1.0, 10.0, 3.5, step=0.1, label="Guidance", interactive=not is_schnell
                    )
                    seed = gr.Textbox(-1, label="Seed (-1 for random)")
                    add_sampling_metadata = gr.Checkbox(
                        label="Add sampling parameters to metadata?", value=True
                    )

                generate_btn = gr.Button("Generate")

            with gr.Column():
                output_image = gr.Image(label="Generated Image")
                seed_output = gr.Number(label="Used Seed")
                warning_text = gr.Textbox(label="Warning", visible=False)
                download_btn = gr.File(label="Download full-resolution")

        def update_img2img(do_img2img):
            return {
                init_image: gr.update(visible=do_img2img),
                image2image_strength: gr.update(visible=do_img2img),
            }

        do_img2img.change(update_img2img, do_img2img, [init_image, image2image_strength])

        generate_btn.click(
            fn=generator.generate_image,
            inputs=[
                width,
                height,
                num_steps,
                guidance,
                seed,
                prompt,
                init_image,
                image2image_strength,
                add_sampling_metadata,
            ],
            outputs=[output_image, seed_output, download_btn, warning_text],
        )

    return demo


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Flux")
    parser.add_argument(
        "--name", type=str, default="flux-schnell", choices=list(configs.keys()), help="Model name"
    )
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use"
    )
    parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use")
    parser.add_argument("--share", action="store_true", help="Create a public link to your demo")
    parser.add_argument("--track_usage", action="store_true", help="Track usage for licensing purposes")
    args = parser.parse_args()

    demo = create_demo(args.name, args.device, args.offload, args.track_usage)
    demo.launch(share=args.share)
================================================
FILE: demo_st.py
================================================
import os
import re
import time
from glob import iglob
from io import BytesIO
import streamlit as st
import torch
from einops import rearrange
from fire import Fire
from PIL import ExifTags, Image
from st_keyup import st_keyup
from torchvision import transforms
from transformers import pipeline
from flux.cli import SamplingOptions
from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from flux.util import (
    configs,
    embed_watermark,
    load_ae,
    load_clip,
    load_flow_model,
    load_t5,
    track_usage_via_api,
)
NSFW_THRESHOLD = 0.85


@st.cache_resource()
def get_models(name: str, device: torch.device, offload: bool, is_schnell: bool):
    t5 = load_t5(device, max_length=256 if is_schnell else 512)
    clip = load_clip(device)
    model = load_flow_model(name, device="cpu" if offload else device)
    ae = load_ae(name, device="cpu" if offload else device)
    nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
    return model, ae, t5, clip, nsfw_classifier


def get_image() -> torch.Tensor | None:
    image = st.file_uploader("Input", type=["jpg", "JPEG", "png"])
    if image is None:
        return None
    image = Image.open(image).convert("RGB")

    # PIL image -> CHW float tensor in [0, 1] -> rescaled to [-1, 1]
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Lambda(lambda x: 2.0 * x - 1.0),
        ]
    )
    img: torch.Tensor = transform(image)
    return img[None, ...]


@torch.inference_mode()
def main(
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    offload: bool = False,
    output_dir: str = "output",
    track_usage: bool = False,
):
    torch_device = torch.device(device)
    names = list(configs.keys())
    name = st.selectbox("Which model to load?", names)
    if name is None or not st.checkbox("Load model", False):
        return

    is_schnell = name == "flux-schnell"
    model, ae, t5, clip, nsfw_classifier = get_models(
        name,
        device=torch_device,
        offload=offload,
        is_schnell=is_schnell,
    )

    do_img2img = (
        st.checkbox(
            "Image to Image",
            False,
            disabled=is_schnell,
            help="Partially noise an image and denoise again to get variations.\n\nOnly works for flux-dev",
        )
        and not is_schnell
    )
    if do_img2img:
        init_image = get_image()
        if init_image is None:
            st.warning("Please add an image to do image to image")
        image2image_strength = st.number_input("Noising strength", min_value=0.0, max_value=1.0, value=0.8)
        if init_image is not None:
            h, w = init_image.shape[-2:]
            st.write(f"Got image of size {w}x{h} ({h * w / 1e6:.2f}MP)")
        resize_img = st.checkbox("Resize image", False) or init_image is None
    else:
        init_image = None
        resize_img = True
        image2image_strength = 0.0

    # allow for packing and conversion to latent space
    width = int(
        16 * (st.number_input("Width", min_value=128, value=1360, step=16, disabled=not resize_img) // 16)
    )
    height = int(
        16 * (st.number_input("Height", min_value=128, value=768, step=16, disabled=not resize_img) // 16)
    )
    num_steps = int(st.number_input("Number of steps", min_value=1, value=(4 if is_schnell else 50)))
    guidance = float(st.number_input("Guidance", min_value=1.0, value=3.5, disabled=is_schnell))
    seed_str = st.text_input("Seed", disabled=is_schnell)
    if seed_str.isdecimal():
        seed = int(seed_str)
    else:
        st.info("No seed set, set to positive integer to enable")
        seed = None
    save_samples = st.checkbox("Save samples?", not is_schnell)
    add_sampling_metadata = st.checkbox("Add sampling parameters to metadata?", True)

    default_prompt = (
        "a photo of a forest with mist swirling around the tree trunks. The word "
        '"FLUX" is painted over it in big, red brush strokes with visible texture'
    )
    prompt = st_keyup("Enter a prompt", value=default_prompt, debounce=300, key="interactive_text")

    output_name = os.path.join(output_dir, "img_{idx}.jpg")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        idx = 0
    else:
        # continue numbering after the highest existing img_<n>.jpg
        fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
        if len(fns) > 0:
            idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
        else:
            idx = 0

    rng = torch.Generator(device="cpu")

    if "seed" not in st.session_state:
        st.session_state.seed = rng.seed()

    def increment_counter():
st.session_state.seed += 1
def decrement_counter():
if st.session_state.seed > 0:
st.session_state.seed -= 1
opts = SamplingOptions(
prompt=prompt,
width=width,
height=height,
num_steps=num_steps,
guidance=guidance,
seed=seed,
)
if name == "flux-schnell":
cols = st.columns([5, 1, 1, 5])
with cols[1]:
st.button("↩", on_click=increment_counter)
with cols[2]:
st.button("↪", on_click=decrement_counter)
if is_schnell or st.button("Sample"):
if is_schnell:
opts.seed = st.session_state.seed
elif opts.seed is None:
opts.seed = rng.seed()
print(f"Generating '{opts.prompt}' with seed {opts.seed}")
t0 = time.perf_counter()
if init_image is not None:
if resize_img:
init_image = torch.nn.functional.interpolate(init_image, (opts.height, opts.width))
else:
h, w = init_image.shape[-2:]
init_image = init_image[..., : 16 * (h // 16), : 16 * (w // 16)]
opts.height = init_image.shape[-2]
opts.width = init_image.shape[-1]
if offload:
ae.encoder.to(torch_device)
init_image = ae.encode(init_image.to(torch_device))
if offload:
ae = ae.cpu()
torch.cuda.empty_cache()
# prepare input
x = get_noise(
1,
opts.height,
opts.width,
device=torch_device,
dtype=torch.bfloat16,
seed=opts.seed,
)
# divide pixel space by 16**2 to account for latent space conversion
timesteps = get_schedule(
opts.num_steps,
(x.shape[-1] * x.shape[-2]) // 4,
shift=(not is_schnell),
)
if init_image is not None:
t_idx = int((1 - image2image_strength) * num_steps)
t = timesteps[t_idx]
timesteps = timesteps[t_idx:]
x = t * x + (1.0 - t) * init_image.to(x.dtype)
if offload:
t5, clip = t5.to(torch_device), clip.to(torch_device)
inp = prepare(t5=t5, clip=clip, img=x, prompt=opts.prompt)
# offload TEs to CPU, load model to gpu
if offload:
t5, clip = t5.cpu(), clip.cpu()
torch.cuda.empty_cache()
model = model.to(torch_device)
# denoise initial noise
x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
# offload model, load autoencoder to gpu
if offload:
model.cpu()
torch.cuda.empty_cache()
ae.decoder.to(x.device)
# decode latents to pixel space
x = unpack(x.float(), opts.height, opts.width)
with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
x = ae.decode(x)
if offload:
ae.decoder.cpu()
torch.cuda.empty_cache()
t1 = time.perf_counter()
fn = output_name.format(idx=idx)
print(f"Done in {t1 - t0:.1f}s.")
# bring into PIL format and save
x = x.clamp(-1, 1)
x = embed_watermark(x.float())
x = rearrange(x[0], "c h w -> h w c")
img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0]
if nsfw_score < NSFW_THRESHOLD:
buffer = BytesIO()
exif_data = Image.Exif()
if init_image is None:
exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux"
else:
exif_data[ExifTags.Base.Software] = "AI generated;img2img;flux"
exif_data[ExifTags.Base.Make] = "Black Forest Labs"
exif_data[ExifTags.Base.Model] = name
if add_sampling_metadata:
exif_data[ExifTags.Base.ImageDescription] = prompt
img.save(buffer, format="jpeg", exif=exif_data, quality=95, subsampling=0)
img_bytes = buffer.getvalue()
if save_samples:
print(f"Saving {fn}")
with open(fn, "wb") as file:
file.write(img_bytes)
idx += 1
if track_usage:
track_usage_via_api(name, 1)
st.session_state["samples"] = {
"prompt": opts.prompt,
"img": img,
"seed": opts.seed,
"bytes": img_bytes,
}
opts.seed = None
else:
st.warning("Your generated image may contain NSFW content.")
st.session_state["samples"] = None
samples = st.session_state.get("samples", None)
if samples is not None:
st.image(samples["img"], caption=samples["prompt"])
st.download_button(
"Download full-resolution",
samples["bytes"],
file_name="generated.jpg",
mime="image/jpeg",
)
st.write(f"Seed: {samples['seed']}")
def app():
Fire(main)
if __name__ == "__main__":
app()
================================================
FILE: demo_st_fill.py
================================================
import os
import re
import tempfile
import time
from glob import iglob
from io import BytesIO
import numpy as np
import streamlit as st
import torch
from einops import rearrange
from fire import Fire
from PIL import ExifTags, Image
from st_keyup import st_keyup
from streamlit_drawable_canvas import st_canvas
from transformers import pipeline
from flux.sampling import denoise, get_noise, get_schedule, prepare_fill, unpack
from flux.util import (
embed_watermark,
load_ae,
load_clip,
load_flow_model,
load_t5,
track_usage_via_api,
)
NSFW_THRESHOLD = 0.85
def add_border_and_mask(image, zoom_all=1.0, zoom_left=0, zoom_right=0, zoom_up=0, zoom_down=0, overlap=0):
"""Adds a black border around the image with individual side control and mask overlap"""
orig_width, orig_height = image.size
# Calculate padding for each side (in pixels)
left_pad = int(orig_width * zoom_left)
right_pad = int(orig_width * zoom_right)
top_pad = int(orig_height * zoom_up)
bottom_pad = int(orig_height * zoom_down)
# Calculate overlap in pixels
overlap_left = int(orig_width * overlap)
overlap_right = int(orig_width * overlap)
overlap_top = int(orig_height * overlap)
overlap_bottom = int(orig_height * overlap)
# If using the all-sides zoom, add it to each side
if zoom_all > 1.0:
extra_each_side = (zoom_all - 1.0) / 2
left_pad += int(orig_width * extra_each_side)
right_pad += int(orig_width * extra_each_side)
top_pad += int(orig_height * extra_each_side)
bottom_pad += int(orig_height * extra_each_side)
# Calculate new dimensions (ensure they're multiples of 32)
new_width = 32 * round((orig_width + left_pad + right_pad) / 32)
new_height = 32 * round((orig_height + top_pad + bottom_pad) / 32)
# Create new image with black border
bordered_image = Image.new("RGB", (new_width, new_height), (0, 0, 0))
# Paste original image in position
paste_x = left_pad
paste_y = top_pad
bordered_image.paste(image, (paste_x, paste_y))
# Create mask (white where the border is, black where the original image was)
mask = Image.new("L", (new_width, new_height), 255) # White background
# Paste black rectangle with overlap adjustment
mask.paste(
0,
(
paste_x + overlap_left, # Left edge moves right
paste_y + overlap_top, # Top edge moves down
paste_x + orig_width - overlap_right, # Right edge moves left
paste_y + orig_height - overlap_bottom, # Bottom edge moves up
),
)
return bordered_image, mask
@st.cache_resource()
def get_models(name: str, device: torch.device, offload: bool):
t5 = load_t5(device, max_length=128)
clip = load_clip(device)
model = load_flow_model(name, device="cpu" if offload else device)
ae = load_ae(name, device="cpu" if offload else device)
nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
return model, ae, t5, clip, nsfw_classifier
def resize(img: Image.Image, min_mp: float = 0.5, max_mp: float = 2.0) -> Image.Image:
width, height = img.size
mp = (width * height) / 1_000_000 # Current megapixels
if min_mp <= mp <= max_mp:
# Even if MP is in range, ensure dimensions are multiples of 32
new_width = int(32 * round(width / 32))
new_height = int(32 * round(height / 32))
if new_width != width or new_height != height:
return img.resize((new_width, new_height), Image.Resampling.LANCZOS)
return img
# Calculate scaling factor
if mp < min_mp:
scale = (min_mp / mp) ** 0.5
else: # mp > max_mp
scale = (max_mp / mp) ** 0.5
new_width = int(32 * round(width * scale / 32))
new_height = int(32 * round(height * scale / 32))
return img.resize((new_width, new_height), Image.Resampling.LANCZOS)
def clear_canvas_state():
"""Clear all canvas-related state"""
keys_to_clear = ["canvas", "last_image_dims"]
for key in keys_to_clear:
if key in st.session_state:
del st.session_state[key]
def set_new_image(img: Image.Image):
"""Safely set a new image and clear relevant state"""
st.session_state["current_image"] = img
clear_canvas_state()
st.rerun()
def downscale_image(img: Image.Image, scale_factor: float) -> Image.Image:
"""Downscale image by a given factor while maintaining 32-pixel multiple dimensions"""
if scale_factor >= 1.0:
return img
width, height = img.size
new_width = int(32 * round(width * scale_factor / 32))
new_height = int(32 * round(height * scale_factor / 32))
# Ensure minimum dimensions
new_width = max(64, new_width) # minimum 64 pixels
new_height = max(64, new_height) # minimum 64 pixels
return img.resize((new_width, new_height), Image.Resampling.LANCZOS)
@torch.inference_mode()
def main(
device: str = "cuda" if torch.cuda.is_available() else "cpu",
offload: bool = False,
output_dir: str = "output",
track_usage: bool = False,
):
torch_device = torch.device(device)
st.title("Flux Fill: Inpainting & Outpainting")
# Model selection and loading
name = "flux-dev-fill"
if not st.checkbox("Load model", False):
return
try:
model, ae, t5, clip, nsfw_classifier = get_models(
name,
device=torch_device,
offload=offload,
)
except Exception as e:
st.error(f"Error loading models: {e}")
return
# Mode selection
mode = st.radio("Select Mode", ["Inpainting", "Outpainting"])
# Image handling - either from previous generation or new upload
if "input_image" in st.session_state:
image = st.session_state["input_image"]
del st.session_state["input_image"]
set_new_image(image)
st.write("Continuing from previous result")
else:
uploaded_image = st.file_uploader("Upload image", type=["jpg", "jpeg", "png"])
if uploaded_image is None:
st.warning("Please upload an image")
return
if (
"current_image_name" not in st.session_state
or st.session_state["current_image_name"] != uploaded_image.name
):
try:
image = Image.open(uploaded_image).convert("RGB")
st.session_state["current_image_name"] = uploaded_image.name
set_new_image(image)
except Exception as e:
st.error(f"Error loading image: {e}")
return
else:
image = st.session_state.get("current_image")
if image is None:
st.error("Error: Image state is invalid. Please reupload the image.")
clear_canvas_state()
return
# Add downscale control
with st.expander("Image Size Control"):
current_mp = (image.size[0] * image.size[1]) / 1_000_000
st.write(f"Current image size: {image.size[0]}x{image.size[1]} ({current_mp:.1f}MP)")
scale_factor = st.slider(
"Downscale Factor",
min_value=0.1,
max_value=1.0,
value=1.0,
step=0.1,
help="1.0 = original size, 0.5 = half size, etc.",
)
if scale_factor < 1.0 and st.button("Apply Downscaling"):
image = downscale_image(image, scale_factor)
set_new_image(image)
st.rerun()
# Resize image with validation
try:
original_mp = (image.size[0] * image.size[1]) / 1_000_000
image = resize(image)
width, height = image.size
current_mp = (width * height) / 1_000_000
if width % 32 != 0 or height % 32 != 0:
st.error("Error: Image dimensions must be multiples of 32")
return
st.write(f"Image dimensions: {width}x{height} pixels")
if original_mp != current_mp:
st.write(
f"Image has been resized from {original_mp:.1f}MP to {current_mp:.1f}MP to stay within bounds (0.5MP - 2MP)"
)
except Exception as e:
st.error(f"Error processing image: {e}")
return
if mode == "Outpainting":
# Outpainting controls
zoom_all = st.slider("Zoom Out Amount (All Sides)", min_value=1.0, max_value=3.0, value=1.0, step=0.1)
with st.expander("Advanced Zoom Controls"):
st.info("These controls add additional zoom to specific sides")
col1, col2 = st.columns(2)
with col1:
zoom_left = st.slider("Left", min_value=0.0, max_value=1.0, value=0.0, step=0.1)
zoom_right = st.slider("Right", min_value=0.0, max_value=1.0, value=0.0, step=0.1)
with col2:
zoom_up = st.slider("Up", min_value=0.0, max_value=1.0, value=0.0, step=0.1)
zoom_down = st.slider("Down", min_value=0.0, max_value=1.0, value=0.0, step=0.1)
overlap = st.slider("Overlap", min_value=0.01, max_value=0.25, value=0.01, step=0.01)
# Generate bordered image and mask
image_for_generation, mask = add_border_and_mask(
image,
zoom_all=zoom_all,
zoom_left=zoom_left,
zoom_right=zoom_right,
zoom_up=zoom_up,
zoom_down=zoom_down,
overlap=overlap,
)
width, height = image_for_generation.size
# Show preview
col1, col2 = st.columns(2)
with col1:
st.image(image_for_generation, caption="Image with Border")
with col2:
st.image(mask, caption="Mask (white areas will be generated)")
else: # Inpainting mode
# Canvas setup with dimension tracking
canvas_key = f"canvas_{width}_{height}"
if "last_image_dims" not in st.session_state:
st.session_state.last_image_dims = (width, height)
elif st.session_state.last_image_dims != (width, height):
clear_canvas_state()
st.session_state.last_image_dims = (width, height)
st.rerun()
try:
canvas_result = st_canvas(
fill_color="rgba(255, 255, 255, 0.0)",
stroke_width=st.slider("Brush size", 1, 500, 50),
stroke_color="#fff",
background_image=image,
height=height,
width=width,
drawing_mode="freedraw",
key=canvas_key,
display_toolbar=True,
)
except Exception as e:
st.error(f"Error creating canvas: {e}")
clear_canvas_state()
st.rerun()
return
# Sampling parameters
num_steps = int(st.number_input("Number of steps", min_value=1, value=50))
guidance = float(st.number_input("Guidance", min_value=1.0, value=30.0))
seed_str = st.text_input("Seed")
if seed_str.isdecimal():
seed = int(seed_str)
else:
st.info("No seed set, using random seed")
seed = None
save_samples = st.checkbox("Save samples?", True)
add_sampling_metadata = st.checkbox("Add sampling parameters to metadata?", True)
# Prompt input
prompt = st_keyup("Enter a prompt", value="", debounce=300, key="interactive_text")
# Setup output path
output_name = os.path.join(output_dir, "img_{idx}.jpg")
if not os.path.exists(output_dir):
os.makedirs(output_dir)
idx = 0
else:
fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
idx = max((int(fn.split("_")[-1].split(".")[0]) for fn in fns), default=-1) + 1
if st.button("Generate"):
valid_input = False
if mode == "Inpainting" and canvas_result.image_data is not None:
valid_input = True
# Create mask from canvas
try:
mask = Image.fromarray(canvas_result.image_data)
mask = mask.getchannel("A") # Get alpha channel
mask_array = np.array(mask)
mask_array = (mask_array > 0).astype(np.uint8) * 255
mask = Image.fromarray(mask_array)
image_for_generation = image
except Exception as e:
st.error(f"Error creating mask: {e}")
return
elif mode == "Outpainting":
valid_input = True
# image_for_generation and mask are already set above
if not valid_input:
st.error("Please draw a mask or configure outpainting settings")
return
# Create temporary files
with (
tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_img,
tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_mask,
):
try:
image_for_generation.save(tmp_img.name)
mask.save(tmp_mask.name)
except Exception as e:
st.error(f"Error saving temporary files: {e}")
return
try:
# Generate inpainting/outpainting
rng = torch.Generator(device="cpu")
if seed is None:
seed = rng.seed()
print(f"Generating with seed {seed}:\n{prompt}")
t0 = time.perf_counter()
x = get_noise(
1,
height,
width,
device=torch_device,
dtype=torch.bfloat16,
seed=seed,
)
if offload:
t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device)
inp = prepare_fill(
t5,
clip,
x,
prompt=prompt,
ae=ae,
img_cond_path=tmp_img.name,
mask_path=tmp_mask.name,
)
timesteps = get_schedule(num_steps, inp["img"].shape[1], shift=True)
if offload:
t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu()
torch.cuda.empty_cache()
model = model.to(torch_device)
x = denoise(model, **inp, timesteps=timesteps, guidance=guidance)
if offload:
model.cpu()
torch.cuda.empty_cache()
ae.decoder.to(x.device)
x = unpack(x.float(), height, width)
with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
x = ae.decode(x)
t1 = time.perf_counter()
print(f"Done in {t1 - t0:.1f}s")
# Process and display result
x = x.clamp(-1, 1)
x = embed_watermark(x.float())
x = rearrange(x[0], "c h w -> h w c")
img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
nsfw_score = [x["score"] for x in nsfw_classifier(img) if x["label"] == "nsfw"][0]
if nsfw_score < NSFW_THRESHOLD:
buffer = BytesIO()
exif_data = Image.Exif()
exif_data[ExifTags.Base.Software] = f"AI generated;{mode.lower()};flux"
exif_data[ExifTags.Base.Make] = "Black Forest Labs"
exif_data[ExifTags.Base.Model] = name
if add_sampling_metadata:
exif_data[ExifTags.Base.ImageDescription] = prompt
img.save(buffer, format="jpeg", exif=exif_data, quality=95, subsampling=0)
img_bytes = buffer.getvalue()
if save_samples:
fn = output_name.format(idx=idx)
print(f"Saving {fn}")
with open(fn, "wb") as file:
file.write(img_bytes)
if track_usage:
track_usage_via_api(name, 1)
st.session_state["samples"] = {
"prompt": prompt,
"img": img,
"seed": seed,
"bytes": img_bytes,
}
else:
st.warning("Your generated image may contain NSFW content.")
st.session_state["samples"] = None
except Exception as e:
st.error(f"Error during generation: {e}")
return
finally:
# Clean up temporary files
try:
os.unlink(tmp_img.name)
os.unlink(tmp_mask.name)
except Exception as e:
print(f"Error cleaning up temporary files: {e}")
# Display results
samples = st.session_state.get("samples", None)
if samples is not None:
st.image(samples["img"], caption=samples["prompt"])
col1, col2 = st.columns(2)
with col1:
st.download_button(
"Download full-resolution",
samples["bytes"],
file_name="generated.jpg",
mime="image/jpeg",
)
with col2:
if st.button("Continue from this image"):
# Store the generated image
new_image = samples["img"]
# Clear ALL canvas state
clear_canvas_state()
if "samples" in st.session_state:
del st.session_state["samples"]
# Set as current image
st.session_state["current_image"] = new_image
st.rerun()
st.write(f"Seed: {samples['seed']}")
def app():
Fire(main)
if __name__ == "__main__":
st.set_page_config(layout="wide")
app()
================================================
FILE: docs/fill.md
================================================
## Open-weight models
FLUX.1 Fill introduces advanced inpainting and outpainting capabilities. It allows for seamless edits that integrate naturally with existing images.
| Name | HuggingFace repo | License | sha256sum |
| ------------------- | -------------------------------------------------------- | --------------------------------------------------------------------- | ---------------------------------------------------------------- |
| `FLUX.1 Fill [dev]` | https://huggingface.co/black-forest-labs/FLUX.1-Fill-dev | [FLUX.1-dev Non-Commercial License](../model_licenses/LICENSE-FLUX1-dev) | 03e289f530df51d014f48e675a9ffa2141bc003259bf5f25d75b957e920a41ca |
## Examples


## Open-weights usage
The weights will be downloaded automatically to `checkpoints/` from HuggingFace once you start one of the demos. Alternatively, you may download the weights manually and put them in `checkpoints/`, or you can also manually link them with the following environment variables:
```bash
export FLUX_MODEL=<your model path here>
export FLUX_AE=<your autoencoder path here>
```
For interactive sampling, run:
```bash
python -m flux fill --loop
```
To generate a single sample, run:
```bash
python -m flux fill \
--img_cond_path <path_to_input_image> \
--img_mask_path <path_to_input_mask>
```
The input mask should be an image of the same size as the conditioning image that contains only black and white pixels, where white marks the region to be generated; see [an example mask](../assets/cup_mask.png) for [this image](../assets/cup.png).
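A mask in this format can be prepared with Pillow. The sketch below is not part of the repository: the helper names `make_box_mask` and `binarize_mask` are hypothetical, and it assumes the convention used in `demo_st_fill.py`, where white (255) marks the area to be generated and black (0) the area to keep.

```python
from PIL import Image, ImageDraw


def make_box_mask(size: tuple[int, int], box: tuple[int, int, int, int]) -> Image.Image:
    """Return a black "L"-mode mask of `size` (width, height) with a white rectangle at `box`.

    `box` is (left, top, right, bottom) in pixels; white marks the area to inpaint.
    """
    mask = Image.new("L", size, 0)  # black everywhere: keep the original pixels
    ImageDraw.Draw(mask).rectangle(box, fill=255)  # white: region to regenerate
    return mask


def binarize_mask(mask: Image.Image, threshold: int = 127) -> Image.Image:
    """Force a grayscale mask to pure black/white, e.g. after painting with a soft brush."""
    return mask.convert("L").point(lambda v: 255 if v > threshold else 0)
```

For example, `make_box_mask(Image.open("cup.png").size, (100, 120, 300, 340)).save("my_mask.png")` writes a file that can be passed via `--img_mask_path`; the box coordinates are arbitrary. The threshold of 127 is a choice of this sketch; the Streamlit demo instead treats any pixel with non-zero alpha as masked.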
We also provide an interactive Streamlit demo, which can be run via
```bash
streamlit run demo_st_fill.py
```
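For outpainting, the Streamlit demo pads the input with a black border and builds a mask that is white over the border and black over the original image, shrunk inward by a small overlap so that the model can blend the seam (see `add_border_and_mask` in `demo_st_fill.py`). The following is a simplified sketch of that idea with equal padding on every side; the function name `pad_for_outpainting` and its parameters are hypothetical, while the rounding of the canvas to multiples of 32 mirrors the demo.

```python
from PIL import Image


def pad_for_outpainting(image: Image.Image, pad_frac: float = 0.25, overlap_frac: float = 0.01):
    """Pad `image` on every side and return (padded_image, mask).

    pad_frac: border width per side, as a fraction of the original width/height.
    overlap_frac: how far the white (generated) region reaches into the original image.
    """
    w, h = image.size
    pad_x, pad_y = int(w * pad_frac), int(h * pad_frac)
    # Round the canvas to a multiple of 32, as demo_st_fill.py does.
    new_w = 32 * round((w + 2 * pad_x) / 32)
    new_h = 32 * round((h + 2 * pad_y) / 32)

    padded = Image.new("RGB", (new_w, new_h), (0, 0, 0))  # black border
    padded.paste(image, (pad_x, pad_y))

    ov_x, ov_y = int(w * overlap_frac), int(h * overlap_frac)
    mask = Image.new("L", (new_w, new_h), 255)  # white: to be generated
    # Black: keep the original pixels, minus an overlap band along each edge.
    mask.paste(0, (pad_x + ov_x, pad_y + ov_y, pad_x + w - ov_x, pad_y + h - ov_y))
    return padded, mask
```

The padded image and mask can then be saved and passed via `--img_cond_path` and `--img_mask_path`. Unlike the demo, this sketch has no per-side zoom controls.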
================================================
FILE: docs/image-editing.md
================================================
## Open-weight models
We currently offer one open-weight image-editing model.
| Name | HuggingFace repo | License | sha256sum |
| ------------------------- | ----------------------------------------------------------------| --------------------------------------------------------------------- | ---------------------------------------------------------------- |
| `FLUX.1 Kontext [dev]` | https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev | [FLUX.1-dev Non-Commercial License](../model_licenses/LICENSE-FLUX1-dev) | 843a26dc765d3105dba081c30bce7b14c65b0988f9e8d14e9fbc8856a6deebd5 |
## Examples
![FLUX.1 Kontext [dev] Grid](../assets/docs/kontext.png)
## Open-weights usage
The weights will be downloaded automatically to `checkpoints/` from HuggingFace once you start one of the demos. Alternatively, you may download the weights manually and put them in `checkpoints/`, or you can also manually link them with the following environment variables:
```bash
export FLUX_MODEL=<your model path here>
export FLUX_AE=<your autoencoder path here>
```
For interactive sampling, run:
```bash
python -m flux kontext --loop
```
To generate a single sample, run:
```bash
python -m flux kontext \
--img_cond_path <path_to_input_image> \
--prompt <your_prompt> \
--num_steps 30 --aspect_ratio "16:9" --guidance 2.5 --seed 1
```
Note that the flags `num_steps`, `aspect_ratio`, `guidance` and `seed` are
optional. For more available flags see [the code](../src/flux/cli_kontext.py).
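To script several edits, the single-sample command can be assembled programmatically. This sketch only builds the argument list from the flags shown above and does not verify them against the CLI; the helper `kontext_command` is hypothetical.

```python
from __future__ import annotations


def kontext_command(
    img_cond_path: str,
    prompt: str,
    num_steps: int | None = None,
    aspect_ratio: str | None = None,
    guidance: float | None = None,
    seed: int | None = None,
) -> list[str]:
    """Build argv for one `python -m flux kontext` run; optional flags are added only when set."""
    cmd = ["python", "-m", "flux", "kontext", "--img_cond_path", img_cond_path, "--prompt", prompt]
    optional = {
        "--num_steps": num_steps,
        "--aspect_ratio": aspect_ratio,
        "--guidance": guidance,
        "--seed": seed,
    }
    for flag, value in optional.items():
        if value is not None:
            cmd += [flag, str(value)]
    return cmd
```

For example, `subprocess.run(kontext_command("input.png", "make it snowy", seed=1), check=True)` runs one edit. Passing an argument list rather than a shell string avoids having to quote prompts that contain spaces.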
### TRT engine inference
We provide exports in BF16, FP8, and FP4 precision. Note that you need to install the repository with TensorRT support as outlined [here](../README.md).
```bash
python -m flux kontext --loop --trt --trt_transformer_precision <precision>
```
where `<precision>` is one of `bf16`, `fp8`, or `fp4_sdvd32`.
================================================
FILE: docs/image-variation.md
================================================
## Models
FLUX.1 Redux is an adapter for the FLUX.1 text-to-image base models, FLUX.1 [dev] and FLUX.1 [schnell], which can be used to generate image variations.
| Name | HuggingFace repo | License | sha256sum |
| --------------------------- | ----------------------------------------------------------- | --------------------------------------------------------------------- | ---------------------------------------------------------------- |
| `FLUX.1 Redux [dev]` | https://huggingface.co/black-forest-labs/FLUX.1-Redux-dev | [FLUX.1-dev Non-Commercial License](../model_licenses/LICENSE-FLUX1-dev) | a1b3bdcb4bdc58ce04874b9ca776d61fc3e914bb6beab41efb63e4e2694dca45 |
## Examples

## Open-weights usage
The weights will be downloaded automatically to `checkpoints/` from HuggingFace once you start one of the demos. Alternatively, you may download the weights manually and put them in `checkpoints/`, or you can also manually link them with the following environment variables:
```bash
export FLUX_MODEL=<your model path here>
export FLUX_REDUX=<your redux path here>
export FLUX_AE=<your autoencoder path here>
```
For interactive sampling run:
```bash
python -m flux redux --name <name> --loop
```
where `name` specifies the base model, which should be one of `flux-dev` or `flux-schnell`.
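When launching Redux from a script with manually downloaded weights, the three environment variables above can be passed to a child process instead of being exported in the shell. The helpers `redux_env` and `redux_command` below are hypothetical, and the paths you pass in are placeholders for your local files.

```python
from __future__ import annotations

import os
from collections.abc import Mapping

REDUX_BASE_MODELS = ("flux-dev", "flux-schnell")


def redux_env(
    model_path: str, redux_path: str, ae_path: str, base: Mapping[str, str] | None = None
) -> dict[str, str]:
    """Copy the current (or given) environment and point FLUX_MODEL/FLUX_REDUX/FLUX_AE at local weights."""
    env = dict(os.environ if base is None else base)
    env.update({"FLUX_MODEL": model_path, "FLUX_REDUX": redux_path, "FLUX_AE": ae_path})
    return env


def redux_command(name: str) -> list[str]:
    """argv for interactive Redux sampling on top of the chosen base model."""
    if name not in REDUX_BASE_MODELS:
        raise ValueError(f"name must be one of {REDUX_BASE_MODELS}, got {name!r}")
    return ["python", "-m", "flux", "redux", "--name", name, "--loop"]
```

Usage: `subprocess.run(redux_command("flux-dev"), env=redux_env("dev.sft", "redux.sft", "ae.sft"), check=True)`. Copying `os.environ` keeps `PATH` and CUDA settings intact for the child process.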
================================================
FILE: docs/structural-conditioning.md
================================================
## Models
Structural conditioning uses Canny edge or depth detection to maintain precise control during image transformations. By preserving the original image's structure through edge or depth maps, users can make text-guided edits while keeping the core composition intact. This is particularly effective for retexturing images. We release four variations: two based on edge maps (full model and LoRA for FLUX.1 [dev]) and two based on depth maps (full model and LoRA for FLUX.1 [dev]).
| Name | HuggingFace repo | License | sha256sum |
| ------------------------- | -------------------------------------------------------------- | --------------------------------------------------------------------- | ---------------------------------------------------------------- |
| `FLUX.1 Canny [dev]` | https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev | [FLUX.1-dev Non-Commercial License](../model_licenses/LICENSE-FLUX1-dev) | 996876670169591cb412b937fbd46ea14cbed6933aef17c48a2dcd9685c98cdb |
| `FLUX.1 Depth [dev]` | https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev | [FLUX.1-dev Non-Commercial License](../model_licenses/LICENSE-FLUX1-dev) | 41360d1662f44ca45bc1b665fe6387e91802f53911001630d970a4f8be8dac21 |
| `FLUX.1 Canny [dev] LoRA` | https://huggingface.co/black-forest-labs/FLUX.1-Canny-dev-lora | [FLUX.1-dev Non-Commercial License](../model_licenses/LICENSE-FLUX1-dev) | 8eaa21b9c43d5e7242844deb64b8cf22ae9010f813f955ca8c05f240b8a98f7e |
| `FLUX.1 Depth [dev] LoRA` | https://huggingface.co/black-forest-labs/FLUX.1-Depth-dev-lora | [FLUX.1-dev Non-Commercial License](../model_licenses/LICENSE-FLUX1-dev) | 1938b38ea0fdd98080fa3e48beb2bedfbc7ad102d8b65e6614de704a46d8b907 |
## Examples


## Open-weights usage
The weights will be downloaded automatically to `checkpoints/` from HuggingFace once you start one of the demos. Alternatively, you may download the weights manually and put them in `checkpoints/`, or you can also manually link them with the following environment variables:
```bash
export FLUX_MODEL=<your model path here>
export FLUX_AE=<your autoencoder path here>
# optional (see below)
export FLUX_LORA=<your lora path here>
```
Note that the LoRA models (`flux-dev-canny-lora` and `flux-dev-depth-lora`) require the base FLUX.1 [dev] model to be downloaded first. The system will automatically download both the base model and the LoRA adapter when using these variants.
For interactive sampling, run:
```bash
python -m flux control --name <name> --loop
```
where `name` is one of `flux-dev-canny`, `flux-dev-depth`, `flux-dev-canny-lora`, or `flux-dev-depth-lora`.
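The following sketch validates the model name and sets the optional `FLUX_LORA` variable only for the LoRA variants. It is not part of the repository: `control_command` and `control_env` are hypothetical helpers. The sketch assumes that, for the LoRA variants, `FLUX_MODEL` points at the base FLUX.1 [dev] weights and `FLUX_LORA` at the adapter. Rejecting `FLUX_LORA` for the full models is a design choice of this sketch.

```python
from __future__ import annotations

import os
from collections.abc import Mapping

CONTROL_MODELS = ("flux-dev-canny", "flux-dev-depth", "flux-dev-canny-lora", "flux-dev-depth-lora")


def control_command(name: str) -> list[str]:
    """argv for interactive structural-conditioning sampling."""
    if name not in CONTROL_MODELS:
        raise ValueError(f"name must be one of {CONTROL_MODELS}, got {name!r}")
    return ["python", "-m", "flux", "control", "--name", name, "--loop"]


def control_env(
    name: str,
    model_path: str,
    ae_path: str,
    lora_path: str | None = None,
    base: Mapping[str, str] | None = None,
) -> dict[str, str]:
    """Environment for a control run; FLUX_LORA is only set for the *-lora variants."""
    if name not in CONTROL_MODELS:
        raise ValueError(f"unknown control model {name!r}")
    if lora_path is not None and not name.endswith("-lora"):
        raise ValueError(f"FLUX_LORA only applies to LoRA variants, got {name!r}")
    env = dict(os.environ if base is None else base)
    env.update({"FLUX_MODEL": model_path, "FLUX_AE": ae_path})
    if lora_path is not None:
        env["FLUX_LORA"] = lora_path
    return env
```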
### TRT engine inference
We provide exports in BF16, FP8, and FP4 precision. Note that you need to install the repository with TensorRT support as outlined [here](../README.md).
```bash
python -m flux control --name=<name> --loop --img_cond_path="assets/robot.webp" --trt --static_shape=False --trt_transformer_precision <precision>
```
where `<precision>` is either `bf16`, `fp8`, or `fp4`.
## Diffusers usage
Flux Control (including the LoRAs) is also compatible with the `diffusers` Python library. Check out the [documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) to learn more.
================================================
FILE: docs/text-to-image.md
================================================
## Open-weight models
We currently offer three open-weight text-to-image models.
| Name | HuggingFace repo | License | sha256sum |
| --------------------------|----------------------------------------------------------| -------------------------------------------------------------------------|----------------------------------------------------------------- |
| `FLUX.1 [schnell]` | https://huggingface.co/black-forest-labs/FLUX.1-schnell | [apache-2.0](../model_licenses/LICENSE-FLUX1-schnell) | 9403429e0052277ac2a87ad800adece5481eecefd9ed334e1f348723621d2a0a |
| `FLUX.1 [dev]` | https://huggingface.co/black-forest-labs/FLUX.1-dev | [FLUX.1-dev Non-Commercial License](../model_licenses/LICENSE-FLUX1-dev) | 4610115bb0c89560703c892c59ac2742fa821e60ef5871b33493ba544683abd7 |
| `FLUX.1 Krea [dev]` | https://huggingface.co/black-forest-labs/FLUX.1-Krea-dev | [FLUX.1-dev Non-Commercial License](../model_licenses/LICENSE-FLUX1-dev) | 4610115bb0c89560703c892c59ac2742fa821e60ef5871b33493ba544683abd7 |
## Open-weights usage
The weights will be downloaded automatically to `checkpoints/` from HuggingFace once you start one of the demos. Alternatively, you may download the weights manually and put them in `checkpoints/`, or you can also manually link them with the following environment variables:
```bash
export FLUX_MODEL=<your model path here>
export FLUX_AE=<your autoencoder path here>
```
For interactive sampling, run:
```bash
python -m flux t2i --name <name> --loop
```
where `name` is either `flux-dev` or `flux-schnell`. To generate a single sample, run:
```bash
python -m flux t2i --name <name> \
--height <height> --width <width> \
--prompt "<prompt>"
```
### TRT engine inference
We provide exports in BF16, FP8, and FP4 precision. Note that you need to install the repository with TensorRT support as outlined [here](../README.md).
```bash
python -m flux t2i --name=<name> --loop --trt --trt_transformer_precision <precision>
```
where `<precision>` is one of `bf16`, `fp8`, or `fp4`. For ONNX exports, `height` and `width` must each be between 768 and 1344.
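The Streamlit demo rounds the requested width and height down to multiples of 16 (`16 * (value // 16)`), because each 16×16 pixel patch becomes one packed image token after the autoencoder's 8× downsampling and 2×2 packing; `demo_st.py` passes that token count, `(x.shape[-1] * x.shape[-2]) // 4`, to `get_schedule`. The sketch below combines this rounding with the 768 to 1344 range above. `snap_dims` and `num_image_tokens` are hypothetical helpers, and clamping out-of-range sizes (rather than rejecting them) is a choice of this sketch.

```python
import math


def snap_dims(width: int, height: int, lo: int = 768, hi: int = 1344, multiple: int = 16) -> tuple[int, int]:
    """Round each side down to `multiple` (as demo_st.py does), then clamp into [lo, hi].

    lo and hi are themselves multiples of 16, so clamping keeps the result aligned.
    """

    def snap(v: int) -> int:
        v = multiple * (v // multiple)
        return min(max(v, lo), hi)

    return snap(width), snap(height)


def num_image_tokens(width: int, height: int) -> int:
    """Packed image tokens: one per 16x16 pixel patch (8x VAE downsampling, then 2x2 packing)."""
    latent_w = 2 * math.ceil(width / 16)
    latent_h = 2 * math.ceil(height / 16)
    return (latent_w * latent_h) // 4
```

For instance, the demo's default 1360×768 lies outside the export range and would be clamped to 1344×768 by this sketch; at 1360×768 the transformer sees 4080 image tokens.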
### Streamlit and Gradio
We also provide a Streamlit demo that supports both text-to-image and image-to-image generation. The demo can be run via
```bash
streamlit run demo_st.py
```
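In image-to-image mode, `demo_st.py` does not start from pure noise. It skips the first `int((1 - strength) * num_steps)` steps of the schedule and blends the noise with the encoded input image at the first remaining timestep `t`, i.e. `x = t * noise + (1 - t) * latents`. The pure-Python sketch below reproduces only that bookkeeping; it uses a plain linear schedule from 1 to 0 instead of the repository's shifted `get_schedule`, and the function names are hypothetical.

```python
def linear_schedule(num_steps: int) -> list[float]:
    """num_steps + 1 timesteps from 1.0 (pure noise) down to 0.0 (clean image)."""
    return [1.0 - i / num_steps for i in range(num_steps + 1)]


def img2img_start(timesteps: list[float], strength: float, num_steps: int) -> tuple[float, list[float]]:
    """Return (t, remaining_timesteps) using the same truncation rule as demo_st.py."""
    t_idx = int((1 - strength) * num_steps)
    return timesteps[t_idx], timesteps[t_idx:]


def noise_latents(noise: list[float], latents: list[float], t: float) -> list[float]:
    """Element-wise blend t * noise + (1 - t) * latents."""
    return [t * n + (1.0 - t) * z for n, z in zip(noise, latents)]
```

A strength of 1.0 keeps the full schedule and starts from pure noise, while lower strengths keep more of the input image. Because of floating-point rounding, `int((1 - 0.8) * 50)` evaluates to 9 rather than 10, so the demo's default strength of 0.8 with 50 steps skips nine steps.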
We also offer a Gradio-based demo for an interactive experience. To run the Gradio demo:
```bash
python demo_gr.py --name flux-schnell --device cuda
```
Options:
- `--name`: Choose the model to use (options: "flux-schnell", "flux-dev")
- `--device`: Specify the device to use (default: "cuda" if available, otherwise "cpu")
- `--offload`: Offload model to CPU when not in use
- `--share`: Create a public link to your demo
To run the demo with the dev model and create a public link:
```bash
python demo_gr.py --name flux-dev --share
```
## Diffusers integration
`FLUX.1 [schnell]` and `FLUX.1 [dev]` are integrated with the [🧨 diffusers](https://github.com/huggingface/diffusers) library. To use it with diffusers, install it:
```shell
pip install git+https://github.com/huggingface/diffusers.git
```
Then you can use `FluxPipeline` to run the model:
```python
import torch
from diffusers import FluxPipeline
model_id = "black-forest-labs/FLUX.1-schnell"  # you can also use `black-forest-labs/FLUX.1-dev`
pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # save some VRAM by offloading the model to CPU; remove this if you have enough GPU memory
prompt = "A cat holding a sign that says hello world"
seed = 42
image = pipe(
    prompt,
    output_type="pil",
    num_inference_steps=4,  # use a larger number if you are using [dev]
    generator=torch.Generator("cpu").manual_seed(seed),
).images[0]
image.save("flux-schnell.png")
```
To learn more, check out the [diffusers](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux) documentation.
================================================
FILE: model_cards/FLUX.1-Krea-dev.md
================================================
![FLUX.1 Krea [dev] Grid](../assets/flux-1-krea-dev-grid.png)
`FLUX.1 Krea [dev]` is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions.
For more information, please read our [blog post](https://bfl.ai/announcements/flux-1-krea-dev).
# Key Features
1. Cutting-edge output quality, with a focus on aesthetic photography.
2. Competitive prompt following, matching the performance of closed source alternatives.
3. Trained using guidance distillation, making `FLUX.1 Krea [dev]` more efficient.
4. Open weights to drive new scientific research, and empower artists to develop innovative workflows.
5. Generated outputs can be used for personal, scientific, and commercial purposes, as described in the [flux-1-dev-non-commercial-license](https://github.com/black-forest-labs/flux/blob/main/model_licenses/LICENSE-FLUX1-dev).
# Usage
`FLUX.1 Krea [dev]` can be used as a drop-in replacement in every system that supports the original `FLUX.1 [dev]`.
A reference implementation of `FLUX.1 [dev]` is in our dedicated [github repository](https://github.com/black-forest-labs/flux).
Developers and creatives looking to build on top of `FLUX.1 [dev]` are encouraged to use this as a starting point.
`FLUX.1 Krea [dev]` is also available in both [ComfyUI](https://github.com/comfyanonymous/ComfyUI) and [Diffusers](https://github.com/huggingface/diffusers).
---
# Limitations
- This model is not intended or able to provide factual information.
- As a statistical model this checkpoint might amplify existing societal biases.
- The model may fail to generate output that matches the prompts.
- Prompt following is heavily influenced by the prompting style.
---
# Out-of-Scope Use
The model and its derivatives may not be used:
- In any way that violates any applicable national, federal, state, local or international law or regulation.
- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way; including but not limited to the solicitation, creation, acquisition, or dissemination of child exploitative content.
- To generate or disseminate verifiably false information and/or content with the purpose of harming others.
- To generate or disseminate personal identifiable information that can be used to harm an individual.
- To harass, abuse, threaten, stalk, or bully individuals or groups of individuals.
- To create non-consensual nudity or illegal pornographic content.
- For fully automated decision making that adversely impacts an individual's legal rights or otherwise creates or modifies a binding, enforceable obligation.
- Generating or facilitating large-scale disinformation campaigns.
- Please reference our [content filters](https://github.com/black-forest-labs/flux/blob/main/src/flux/content_filters.py) to avoid such generations.
---
# Risks
Black Forest Labs (BFL) and Krea are committed to the responsible development of generative AI technology. Prior to releasing FLUX.1 Krea [dev], BFL and Krea collaboratively evaluated and mitigated a number of risks in the FLUX.1 Krea [dev] model and services, including the generation of unlawful content. We implemented a series of pre-release mitigations to help prevent misuse by third parties, with additional post-release mitigations to help address residual risks:
1. **Pre-training mitigation.** BFL filtered pre-training data for multiple categories of “not safe for work” (NSFW) and unlawful content to help prevent a user generating unlawful content in response to text prompts or uploaded images.
2. **Post-training mitigation.** BFL has partnered with the Internet Watch Foundation, an independent nonprofit organization dedicated to preventing online abuse, to filter known child sexual abuse material (CSAM) from post-training data. Subsequently, BFL and Krea undertook multiple rounds of targeted fine-tuning to provide additional mitigation against potential abuse. By inhibiting certain behaviors and concepts in the trained model, these techniques can help to prevent a user generating synthetic CSAM or nonconsensual intimate imagery (NCII) from a text prompt.
3. **Pre-release evaluation.** Throughout this process, BFL conducted internal and external third-party evaluations of model checkpoints to identify further opportunities for improvement. The third-party evaluations focused on eliciting CSAM and NCII through adversarial testing of the text-to-image model with text-only prompts. We also conducted internal evaluations of the proposed release checkpoints, comparing the model with other leading openly available generative image models from other companies. The final FLUX.1 Krea [dev] open-weight model checkpoint demonstrated very high resilience against violative inputs, and higher resilience than other similar open-weight models across these risk categories. Based on these findings, we approved the release of the FLUX.1 Krea [dev] model as openly available weights under a non-commercial license to support third-party research and development.
4. **Inference filters.** The BFL GitHub repository for the open FLUX.1 Krea [dev] model includes filters for illegal or infringing content. Filters or manual review must be used with the model under the terms of the FLUX.1 [dev] Non-Commercial License. We may approach known deployers of the FLUX.1 Krea [dev] model at random to verify that filters or manual review processes are in place.
5. **Policies.** Our FLUX.1 [dev] Non-Commercial License prohibits the generation of unlawful content or the use of generated content for unlawful, defamatory, or abusive purposes. Developers and users must consent to these conditions to access the FLUX.1 Krea [dev] model.
6. **Monitoring.** BFL is monitoring for patterns of violative use after release, and may ban developers who we detect intentionally and repeatedly violate our policies. Additionally, BFL provides a dedicated email address (safety@blackforestlabs.ai) to solicit feedback from the community. BFL maintains a reporting relationship with organizations such as the Internet Watch Foundation and the National Center for Missing and Exploited Children, and BFL welcomes ongoing engagement with authorities, developers, and researchers to share intelligence about emerging risks and develop effective mitigations.
---
# License
This model falls under the [`FLUX.1 [dev]` Non-Commercial License](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md).
================================================
FILE: model_cards/FLUX.1-dev.md
================================================
![FLUX.1 [dev] Grid](../assets/dev_grid.jpg)
`FLUX.1 [dev]` is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions.
For more information, please read our [blog post](https://blackforestlabs.ai/announcing-black-forest-labs/).
# Key Features
1. Cutting-edge output quality, second only to our state-of-the-art model `FLUX.1 [pro]`.
2. Competitive prompt following, matching the performance of closed source alternatives.
3. Trained using guidance distillation, making `FLUX.1 [dev]` more efficient.
4. Open weights to drive new scientific research, and empower artists to develop innovative workflows.
5. Generated outputs can be used for personal, scientific, and commercial purposes, as described in the [flux-1-dev-non-commercial-license](./licence.md).
# Usage
We provide a reference implementation of `FLUX.1 [dev]`, as well as sampling code, in a dedicated [github repository](https://github.com/black-forest-labs/flux).
Developers and creatives looking to build on top of `FLUX.1 [dev]` are encouraged to use this as a starting point.
## API Endpoints
The FLUX.1 models are also available via API from the following sources:
1. [bfl.ml](https://docs.bfl.ml/) (currently `FLUX.1 [pro]`)
2. [replicate.com](https://replicate.com/collections/flux)
3. [fal.ai](https://fal.ai/models/fal-ai/flux/dev)
## ComfyUI
`FLUX.1 [dev]` is also available in [ComfyUI](https://github.com/comfyanonymous/ComfyUI) for local inference with a node-based workflow.
---
# Limitations
- This model is not intended or able to provide factual information.
- As a statistical model this checkpoint might amplify existing societal biases.
- The model may fail to generate output that matches the prompts.
- Prompt following is heavily influenced by the prompting style.
# Out-of-Scope Use
The model and its derivatives may not be used:
- In any way that violates any applicable national, federal, state, local or international law or regulation.
- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way; including but not limited to the solicitation, creation, acquisition, or dissemination of child exploitative content.
- To generate or disseminate verifiably false information and/or content with the purpose of harming others.
- To generate or disseminate personal identifiable information that can be used to harm an individual.
- To harass, abuse, threaten, stalk, or bully individuals or groups of individuals.
- To create non-consensual nudity or illegal pornographic content.
- For fully automated decision making that adversely impacts an individual's legal rights or otherwise creates or modifies a binding, enforceable obligation.
- Generating or facilitating large-scale disinformation campaigns.
# License
This model falls under the [`FLUX.1 [dev]` Non-Commercial License](https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/LICENSE.md).
================================================
FILE: model_cards/FLUX.1-kontext-dev.md
================================================
![FLUX.1 Kontext [dev]](../assets/docs/kontext.png)
`FLUX.1 Kontext [dev]` is a 12 billion parameter rectified flow transformer capable of editing images based on text instructions.
For more information, please read our [blog post](https://bfl.ai/announcements/flux-1-kontext-dev) and our [technical report](https://arxiv.org/abs/2506.15742).
You can find information about the `[pro]` version [here](https://bfl.ai/models/flux-kontext).
# Key Features
1. Change existing images based on an edit instruction.
2. Use characters, styles, and objects as references without any fine-tuning.
3. Robust consistency allows users to refine an image through multiple successive edits with minimal visual drift.
4. Trained using guidance distillation, making `FLUX.1 Kontext [dev]` more efficient.
5. Open weights to drive new scientific research, and empower artists to develop innovative workflows.
6. Generated outputs can be used for personal, scientific, and commercial purposes, as described in the [FLUX.1 \[dev\] Non-Commercial License](https://github.com/black-forest-labs/flux/blob/main/model_licenses/LICENSE-FLUX1-dev).
# Usage
We provide a reference implementation of `FLUX.1 Kontext [dev]`, as well as sampling code, in a dedicated [github repository](https://github.com/black-forest-labs/flux).
Developers and creatives looking to build on top of `FLUX.1 Kontext [dev]` are encouraged to use this as a starting point.
`FLUX.1 Kontext [dev]` is also available in both [ComfyUI](https://github.com/comfyanonymous/ComfyUI) and [Diffusers](https://github.com/huggingface/diffusers).
## API Endpoints
The FLUX.1 Kontext models are also available via API from the following sources:
- bfl.ai: https://docs.bfl.ai/
- DataCrunch: https://datacrunch.io/flux-kontext
- fal: https://fal.ai/flux-kontext
- Replicate: https://replicate.com/blog/flux-kontext
  - https://replicate.com/black-forest-labs/flux-kontext-dev
  - https://replicate.com/black-forest-labs/flux-kontext-pro
  - https://replicate.com/black-forest-labs/flux-kontext-max
- Runware: https://runware.ai/blog/introducing-flux1-kontext-instruction-based-image-editing-with-ai?utm_source=bfl
- TogetherAI: https://www.together.ai/models/flux-1-kontext-dev
---
# Risks
Black Forest Labs is committed to the responsible development of generative AI technology. Prior to releasing FLUX.1 Kontext, we evaluated and mitigated a number of risks in our models and services, including the generation of unlawful content. We implemented a series of pre-release mitigations to help prevent misuse by third parties, with additional post-release mitigations to help address residual risks:
1. **Pre-training mitigation.** We filtered pre-training data for multiple categories of “not safe for work” (NSFW) content to help prevent a user generating unlawful content in response to text prompts or uploaded images.
2. **Post-training mitigation.** We have partnered with the Internet Watch Foundation, an independent nonprofit organization dedicated to preventing online abuse, to filter known child sexual abuse material (CSAM) from post-training data. Subsequently, we undertook multiple rounds of targeted fine-tuning to provide additional mitigation against potential abuse. By inhibiting certain behaviors and concepts in the trained model, these techniques can help to prevent a user generating synthetic CSAM or nonconsensual intimate imagery (NCII) from a text prompt, or transforming an uploaded image into synthetic CSAM or NCII.
3. **Pre-release evaluation.** Throughout this process, we conducted multiple internal and external third-party evaluations of model checkpoints to identify further opportunities for improvement. The third-party evaluations—which included 21 checkpoints of FLUX.1 Kontext [pro] and [dev]—focused on eliciting CSAM and NCII through adversarial testing with text-only prompts, as well as uploaded images with text prompts. Next, we conducted a final third-party evaluation of the proposed release checkpoints, focused on text-to-image and image-to-image CSAM and NCII generation. The final FLUX.1 Kontext [pro] (as offered through the FLUX API only) and FLUX.1 Kontext [dev] (released as an open-weight model) checkpoints demonstrated very high resilience against violative inputs, and FLUX.1 Kontext [dev] demonstrated higher resilience than other similar open-weight models across these risk categories. Based on these findings, we approved the release of the FLUX.1 Kontext [pro] model via API, and the release of the FLUX.1 Kontext [dev] model as openly-available weights under a non-commercial license to support third-party research and development.
4. **Inference filters.** We are applying multiple filters to intercept text prompts, uploaded images, and output images on the FLUX API for FLUX.1 Kontext [pro]. Filters for CSAM and NCII are provided by Hive, a third-party provider, and cannot be adjusted or removed by developers. We provide filters for other categories of potentially harmful content, including gore, which can be adjusted by developers based on their specific risk profile. Additionally, the repository for the open FLUX.1 Kontext [dev] model includes filters for illegal or infringing content. Filters or manual review must be used with the model under the terms of the FLUX.1 [dev] Non-Commercial License. We may approach known deployers of the FLUX.1 Kontext [dev] model at random to verify that filters or manual review processes are in place.
5. **Content provenance.** The FLUX API applies cryptographically-signed metadata to output content to indicate that images were produced with our model. Our API implements the Coalition for Content Provenance and Authenticity (C2PA) standard for metadata.
6. **Policies.** Access to our API and use of our models are governed by our Developer Terms of Service, Usage Policy, and FLUX.1 [dev] Non-Commercial License, which prohibit the generation of unlawful content or the use of generated content for unlawful, defamatory, or abusive purposes. Developers and users must consent to these conditions to access the FLUX Kontext models.
7. **Monitoring.** We are monitoring for patterns of violative use after release, and may ban developers who we detect intentionally and repeatedly violate our policies via the FLUX API. Additionally, we provide a dedicated email address (safety@blackforestlabs.ai) to solicit feedback from the community. We maintain a reporting relationship with organizations such as the Internet Watch Foundation and the National Center for Missing and Exploited Children, and we welcome ongoing engagement with authorities, developers, and researchers to share intelligence about emerging risks and develop effective mitigations.
# License
This model falls under the [FLUX.1 \[dev\] Non-Commercial License](https://github.com/black-forest-labs/flux/blob/main/model_licenses/LICENSE-FLUX1-dev).
# Citation
```bib
@misc{labs2025flux1kontextflowmatching,
title={FLUX.1 Kontext: Flow Matching for In-Context Image Generation and Editing in Latent Space},
author={Black Forest Labs and Stephen Batifol and Andreas Blattmann and Frederic Boesel and Saksham Consul and Cyril Diagne and Tim Dockhorn and Jack English and Zion English and Patrick Esser and Sumith Kulal and Kyle Lacey and Yam Levi and Cheng Li and Dominik Lorenz and Jonas Müller and Dustin Podell and Robin Rombach and Harry Saini and Axel Sauer and Luke Smith},
year={2025},
eprint={2506.15742},
archivePrefix={arXiv},
primaryClass={cs.GR},
url={https://arxiv.org/abs/2506.15742},
}
```
================================================
FILE: model_cards/FLUX.1-schnell.md
================================================
![FLUX.1 [schnell] Grid](../assets/schnell_grid.jpg)
`FLUX.1 [schnell]` is a 12 billion parameter rectified flow transformer capable of generating images from text descriptions.
For more information, please read our [blog post](https://blackforestlabs.ai/announcing-black-forest-labs/).
# Key Features
1. Cutting-edge output quality and competitive prompt following, matching the performance of closed source alternatives.
2. Trained using latent adversarial diffusion distillation, `FLUX.1 [schnell]` can generate high-quality images in only 1 to 4 steps.
3. Released under the `apache-2.0` licence, the model can be used for personal, scientific, and commercial purposes.
# Usage
We provide a reference implementation of `FLUX.1 [schnell]`, as well as sampling code, in a dedicated [github repository](https://github.com/black-forest-labs/flux).
Developers and creatives looking to build on top of `FLUX.1 [schnell]` are encouraged to use this as a starting point.
## API Endpoints
The FLUX.1 models are also available via API from the following sources:
1. [bfl.ml](https://docs.bfl.ml/) (currently `FLUX.1 [pro]`)
2. [replicate.com](https://replicate.com/collections/flux)
3. [fal.ai](https://fal.ai/models/fal-ai/flux/schnell)
## ComfyUI
`FLUX.1 [schnell]` is also available in [ComfyUI](https://github.com/comfyanonymous/ComfyUI) for local inference with a node-based workflow.
---
# Limitations
- This model is not intended or able to provide factual information.
- As a statistical model this checkpoint might amplify existing societal biases.
- The model may fail to generate output that matches the prompts.
- Prompt following is heavily influenced by the prompting style.
# Out-of-Scope Use
The model and its derivatives may not be used:
- In any way that violates any applicable national, federal, state, local or international law or regulation.
- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way; including but not limited to the solicitation, creation, acquisition, or dissemination of child exploitative content.
- To generate or disseminate verifiably false information and/or content with the purpose of harming others.
- To generate or disseminate personal identifiable information that can be used to harm an individual.
- To harass, abuse, threaten, stalk, or bully individuals or groups of individuals.
- To create non-consensual nudity or illegal pornographic content.
- For fully automated decision making that adversely impacts an individual's legal rights or otherwise creates or modifies a binding, enforceable obligation.
- Generating or facilitating large-scale disinformation campaigns.
================================================
FILE: model_licenses/LICENSE-FLUX1-dev
================================================
FLUX.1 [dev] Non-Commercial License v1.1.1
Black Forest Labs Inc. (“we” or “our” or “Company”) is pleased to make available the weights, parameters and inference code for the FLUX.1 [dev] Model (as defined below) freely available for your non-commercial and non-production use as set forth in this FLUX.1 [dev] Non-Commercial License (“License”). The “FLUX.1 [dev] Model” means the FLUX.1 [dev] AI models and models denoted as FLUX.1 [dev], including but not limited to FLUX.1 [dev], FLUX.1 Fill [dev], FLUX.1 Depth [dev], FLUX.1 Canny [dev], FLUX.1 Redux [dev], FLUX.1 Canny [dev] LoRA, FLUX.1 Depth [dev] LoRA, and FLUX.1 Kontext [dev], and their elements which includes algorithms, software, checkpoints, parameters, source code (inference code, evaluation code, and if applicable, fine-tuning code) and any other materials associated with the FLUX.1 [dev] AI models made available by Company under this License, including if any, the technical documentation, manuals and instructions for the use and operation thereof (collectively, “FLUX.1 [dev] Model”). Note that we may also make available certain elements of what is included in the definition of “FLUX.1 [dev] Model” under a separate license, such as the inference code, and nothing in this License will be deemed to restrict or limit any other licenses granted by us in such elements.
By downloading, accessing, using, Distributing (as defined below), or creating a Derivative (as defined below) of the FLUX.1 [dev] Model, you agree to the terms of this License. If you do not agree to this License, then you do not have any rights to access, use, Distribute or create a Derivative of the FLUX.1 [dev] Model and you must immediately cease using the FLUX.1 [dev] Model. If you are agreeing to be bound by the terms of this License on behalf of your employer or other entity, you represent and warrant to us that you have full legal authority to bind your employer or such entity to this License. If you do not have the requisite authority, you may not accept the License or access the FLUX.1 [dev] Model on behalf of your employer or other entity.
1. Definitions.
- a. “Derivative” means any (i) modified version of the FLUX.1 [dev] Model (including but not limited to any customized or fine-tuned version thereof), (ii) work based on the FLUX.1 [dev] Model, or (iii) any other derivative work thereof. For the avoidance of doubt, Outputs are not considered Derivatives under this License.
- b. “Distribution” or “Distribute” or “Distributing” means providing or making available, by any means, a copy of the FLUX.1 [dev] Models and/or the Derivatives as the case may be.
- c. “Non-Commercial Purpose” means any of the following uses, but only so far as you do not receive any direct or indirect payment arising from the use of the FLUX.1 [dev] Model, Derivatives, or FLUX Content Filters (as defined below): (i) personal use for research, experiment, and testing for the benefit of public knowledge, personal study, private entertainment, hobby projects, or otherwise not directly or indirectly connected to any commercial activities, business operations, or employment responsibilities; (ii) use by commercial or for-profit entities for testing, evaluation, or non-commercial research and development in a non-production environment; and (iii) use by any charitable organization for charitable purposes, or for testing or evaluation. For clarity, use (a) for revenue-generating activity, (b) in direct interactions with or that has impact on end users, or (c) to train, fine tune or distill other models for commercial use, in each case is not a Non-Commercial Purpose.
- d. “Outputs” means any content generated by the operation of the FLUX.1 [dev] Models or the Derivatives from an input (such as an image input) or prompt (i.e., text instructions) provided by users. For the avoidance of doubt, Outputs do not include any components of the FLUX.1 [dev] Models, such as any fine-tuned versions of the FLUX.1 [dev] Models, the weights, or parameters.
- e. “you” or “your” means the individual or entity entering into this License with Company.
2. License Grant.
- a. License. Subject to your compliance with this License, Company grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty free and limited license to access, use, create Derivatives of, and Distribute the FLUX.1 [dev] Models and Derivatives solely for your Non-Commercial Purposes. The foregoing license is personal to you, and you may not assign or sublicense this License or any other rights or obligations under this License without Company’s prior written consent; any such assignment or sublicense will be void and will automatically and immediately terminate this License. Any restrictions set forth herein regarding the FLUX.1 [dev] Model also apply to any Derivative you create or that are created on your behalf.
- b. Non-Commercial Use Only. You may only access, use, Distribute, or create Derivatives of the FLUX.1 [dev] Model or Derivatives for Non-Commercial Purposes. If you want to use a FLUX.1 [dev] Model or a Derivative for any purpose that is not expressly authorized under this License, such as for a commercial activity, you must request a license from Company, which Company may grant to you in Company’s sole discretion and which additional use may be subject to a fee, royalty or other revenue share. Please see www.bfl.ai if you would like a commercial license.
- c. Reserved Rights. The grant of rights expressly set forth in this License are the complete grant of rights to you in the FLUX.1 [dev] Model, and no other licenses are granted, whether by waiver, estoppel, implication, equity or otherwise. Company and its licensors reserve all rights not expressly granted by this License.
- d. Outputs. We claim no ownership rights in and to the Outputs. You are solely responsible for the Outputs you generate and their subsequent uses in accordance with this License. You may use Output for any purpose (including for commercial purposes), except as expressly prohibited herein. You may not use the Output to train, fine-tune or distill a model that is competitive with the FLUX.1 [dev] Model or the FLUX.1 Kontext [dev] Model.
- e. You may access, use, Distribute, or create Output of the FLUX.1 [dev] Model or Derivatives if you: (i) (A) implement and maintain content filtering measures (“Content Filters”) for your use of the FLUX.1 [dev] Model or Derivatives to prevent the creation, display, transmission, generation, or dissemination of unlawful or infringing content, which may include Content Filters that we may make available for use with the FLUX.1 [dev] Model (“FLUX Content Filters”), or (B) ensure Output undergoes review for unlawful or infringing content before public or non-public distribution, display, transmission or dissemination; and (ii) ensure Output includes disclosure (or other indication) that the Output was generated or modified using artificial intelligence technologies to the extent required under applicable law.
3. Distribution. Subject to this License, you may Distribute copies of the FLUX.1 [dev] Model and/or Derivatives made by you, under the following conditions:
- a. you must make available a copy of this License to third-party recipients of the FLUX.1 [dev] Models and/or Derivatives you Distribute, and specify that any rights to use the FLUX.1 [dev] Models and/or Derivatives shall be directly granted by Company to said third-party recipients pursuant to this License;
- b. you must prominently display the following notice alongside the Distribution of the FLUX.1 [dev] Model or Derivative (such as via a “Notice” text file distributed as part of such FLUX.1 [dev] Model or Derivative) (the “Attribution Notice”):
“The FLUX.1 [dev] Model is licensed by Black Forest Labs Inc. under the FLUX.1 [dev] Non-Commercial License. Copyright Black Forest Labs Inc.
IN NO EVENT SHALL BLACK FOREST LABS INC. BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH USE OF THIS MODEL.”
- c. in the case of Distribution of Derivatives made by you: (i) you must also include in the Attribution Notice a statement that you have modified the applicable FLUX.1 [dev] Model; (ii) any terms and conditions you impose on any third-party recipients relating to Derivatives made by or for you shall neither limit such third-party recipients’ use of the FLUX.1 [dev] Model or any Derivatives made by or for Company in accordance with this License nor conflict with any of its terms and conditions and must include disclaimer of warranties and limitation of liability provisions that are at least as protective of Company as those set forth herein; and (iii) you must not misrepresent or imply, through any means, that the Derivatives made by or for you and/or any modified version of the FLUX.1 [dev] Model you Distribute under your name and responsibility is an official product of the Company or has been endorsed, approved or validated by the Company, unless you are authorized by Company to do so in writing.
4. Restrictions. You will not, and will not permit, assist or cause any third party to
- a. use, modify, copy, reproduce, create Derivatives of, or Distribute the FLUX.1 [dev] Model (or any Derivative thereof, or any data produced by the FLUX.1 [dev] Model), in whole or in part, (i) for any commercial or production purposes, (ii) military purposes, (iii) purposes of surveillance, including any research or development relating to surveillance, (iv) biometric processing, (v) in any manner that infringes, misappropriates, or otherwise violates (or is likely to infringe, misappropriate, or otherwise violate) any third party’s legal rights, including rights of publicity or “digital replica” rights, (vi) in any unlawful, fraudulent, defamatory, or abusive activity, (vii) to generate unlawful content, including child sexual abuse material, or non-consensual intimate images; or (viii) in any manner that violates any applicable law and violating any privacy or security laws, rules, regulations, directives, or governmental requirements (including the General Data Privacy Regulation (Regulation (EU) 2016/679), the California Consumer Privacy Act, any and all laws governing the processing of biometric information, and the EU Artificial Intelligence Act (Regulation (EU) 2024/1689), as well as all amendments and successor laws to any of the foregoing;
- b. alter or remove copyright and other proprietary notices which appear on or in any portion of the FLUX.1 [dev] Model;
- c. utilize any equipment, device, software, or other means to circumvent or remove any security or protection used by Company in connection with the FLUX.1 [dev] Model, or to circumvent or remove any usage restrictions, or to enable functionality disabled by FLUX.1 [dev] Model;
- d. offer or impose any terms on the FLUX.1 [dev] Model that alter, restrict, or are inconsistent with the terms of this License;
- e. violate any applicable U.S. and non-U.S. export control and trade sanctions laws (“Export Laws”) in connection with your use or Distribution of any FLUX.1 [dev] Model;
- f. directly or indirectly Distribute, export, or otherwise transfer FLUX.1 [dev] Model (i) to any individual, entity, or country prohibited by Export Laws; (ii) to anyone on U.S. or non-U.S. government restricted parties lists; (iii) for any purpose prohibited by Export Laws, including nuclear, chemical or biological weapons, or missile technology applications; (iv) use or download FLUX.1 [dev] Model if you or they are (a) located in a comprehensively sanctioned jurisdiction, (b) currently listed on any U.S. or non-U.S. restricted parties list, or (c) for any purpose prohibited by Export Laws; and (v) will not disguise your location through IP proxying or other methods.
5. DISCLAIMERS. THE FLUX.1 [dev] MODEL AND FLUX CONTENT FILTERS ARE PROVIDED “AS IS” AND “WITH ALL FAULTS” WITH NO WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. COMPANY EXPRESSLY DISCLAIMS ALL REPRESENTATIONS AND WARRANTIES, EXPRESS OR IMPLIED, WHETHER BY STATUTE, CUSTOM, USAGE OR OTHERWISE AS TO ANY MATTERS RELATED TO THE FLUX.1 [dev] MODEL AND FLUX CONTENT FILTERS, INCLUDING BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE, SATISFACTORY QUALITY, OR NON-INFRINGEMENT. COMPANY MAKES NO WARRANTIES OR REPRESENTATIONS THAT THE FLUX.1 [dev] MODEL AND FLUX CONTENT FILTERS WILL BE ERROR FREE OR FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS, OR PRODUCE ANY PARTICULAR RESULTS.
6. LIMITATION OF LIABILITY. TO THE FULLEST EXTENT PERMITTED BY LAW, IN NO EVENT WILL COMPANY BE LIABLE TO YOU OR YOUR EMPLOYEES, AFFILIATES, USERS, OFFICERS OR DIRECTORS (A) UNDER ANY THEORY OF LIABILITY, WHETHER BASED IN CONTRACT, TORT, NEGLIGENCE, STRICT LIABILITY, WARRANTY, OR OTHERWISE UNDER THIS LICENSE, OR (B) FOR ANY INDIRECT, CONSEQUENTIAL, EXEMPLARY, INCIDENTAL, PUNITIVE OR SPECIAL DAMAGES OR LOST PROFITS, EVEN IF COMPANY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. THE FLUX.1 [dev] MODEL, ITS CONSTITUENT COMPONENTS, FLUX CONTENT FILTERS, AND ANY OUTPUT (COLLECTIVELY, “MODEL MATERIALS”) ARE NOT DESIGNED OR INTENDED FOR USE IN ANY APPLICATION OR SITUATION WHERE FAILURE OR FAULT OF THE MODEL MATERIALS COULD REASONABLY BE ANTICIPATED TO LEAD TO SERIOUS INJURY OF ANY PERSON, INCLUDING POTENTIAL DISCRIMINATION OR VIOLATION OF AN INDIVIDUAL’S PRIVACY RIGHTS, OR TO SEVERE PHYSICAL, PROPERTY, OR ENVIRONMENTAL DAMAGE (EACH, A “HIGH-RISK USE”). IF YOU ELECT TO USE ANY OF THE MODEL MATERIALS FOR A HIGH-RISK USE, YOU DO SO AT YOUR OWN RISK. YOU AGREE TO DESIGN AND IMPLEMENT APPROPRIATE DECISION-MAKING AND RISK-MITIGATION PROCEDURES AND POLICIES IN CONNECTION WITH A HIGH-RISK USE SUCH THAT EVEN IF THERE IS A FAILURE OR FAULT IN ANY OF THE MODEL MATERIALS, THE SAFETY OF PERSONS OR PROPERTY AFFECTED BY THE ACTIVITY STAYS AT A LEVEL THAT IS REASONABLE, APPROPRIATE, AND LAWFUL FOR THE FIELD OF THE HIGH-RISK USE.
7. INDEMNIFICATION. You will indemnify, defend and hold harmless Company and our subsidiaries and affiliates, and each of our respective shareholders, directors, officers, employees, agents, successors, and assigns (collectively, the “Company Parties”) from and against any losses, liabilities, damages, fines, penalties, and expenses (including reasonable attorneys’ fees) incurred by any Company Party in connection with any claim, demand, allegation, lawsuit, proceeding, or investigation (collectively, “Claims”) arising out of or related to (a) your access to or use of the FLUX.1 [dev] Model (including in connection with any Output, results or data generated from such access or use, or from your access or use of any FLUX Content Filters), including any High-Risk Use; (b) your Content Filters, including your failure to implement any Content Filters where required by this License such as in Section 2(e); (c) your violation of this License; or (d) your violation, misappropriation or infringement of any rights of another (including intellectual property or other proprietary rights and privacy rights). You will promptly notify the Company Parties of any such Claims, and cooperate with Company Parties in defending such Claims. You will also grant the Company Parties sole control of the defense or settlement, at Company’s sole option, of any Claims. This indemnity is in addition to, and not in lieu of, any other indemnities or remedies set forth in a written agreement between you and Company or the other Company Parties.
8. Termination; Survival.
- a. This License will automatically terminate upon any breach by you of the terms of this License.
- b. We may terminate this License, in whole or in part, at any time upon notice (including electronic) to you.
- c. If you initiate any legal action or proceedings against Company or any other entity (including a cross-claim or counterclaim in a lawsuit), alleging that the FLUX.1 [dev] Model, any Derivative, or FLUX Content Filters, or any part thereof, infringe upon intellectual property or other rights owned or licensable by you, then any licenses granted to you under this License will immediately terminate as of the date such legal action or claim is filed or initiated.
- d. Upon termination of this License, you must cease all use, access or Distribution of the FLUX.1 [dev] Model, any Derivatives, and any FLUX Content Filters. The following sections survive termination of this License 2(c), 2(d), 4-11.
9. Third Party Materials. The FLUX.1 [dev] Model may contain third-party software or other components (including free and open source software) (all of the foregoing, “Third Party Materials”), which are subject to the license terms of the respective third-party licensors. Your dealings or correspondence with third parties and your use of or interaction with any Third Party Materials are solely between you and the third party. Company does not control or endorse, and makes no representations or warranties regarding, any Third Party Materials, and your access to and use of such Third Party Materials are at your own risk.
10. Trademarks. You have not been granted any trademark license as part of this License and may not use any name, logo or trademark associated with Company without the prior written permission of Company, except to the extent necessary to make the reference required in the Attribution Notice as specified above or as is reasonably necessary in describing the FLUX.1 [dev] Model and its creators.
11. General. This License will be governed and construed under the laws of the State of Delaware without regard to conflicts of law provisions. If any provision or part of a provision of this License is unlawful, void or unenforceable, that provision or part of the provision is deemed severed from this License, and will not affect the validity and enforceability of any remaining provisions. The failure of Company to exercise or enforce any right or provision of this License will not operate as a waiver of such right or provision. This License does not confer any third-party beneficiary rights upon any other person or entity. This License, together with the documentation, contains the entire understanding between you and Company regarding the subject matter of this License, and supersedes all other written or oral agreements and understandings between you and Company regarding such subject matter.
================================================
FILE: model_licenses/LICENSE-FLUX1-schnell
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
================================================
FILE: pyproject.toml
================================================
[project]
name = "flux"
authors = [
{ name = "Black Forest Labs", email = "support@blackforestlabs.ai" },
]
description = "Inference codebase for FLUX"
readme = "README.md"
requires-python = ">=3.10"
license = { file = "LICENSE" }
dynamic = ["version"]
dependencies = [
"accelerate",
"einops",
"fire >= 0.6.0",
"huggingface-hub",
"safetensors",
"sentencepiece",
"transformers >= 4.45.2",
"tokenizers",
"protobuf",
"requests",
"invisible-watermark",
"ruff == 0.6.8",
]
[project.optional-dependencies]
torch = [
"torch == 2.6.0",
"torchvision",
]
streamlit = [
"streamlit",
"streamlit-drawable-canvas",
"streamlit-keyup",
]
gradio = [
"gradio",
]
tensorrt = [
"tensorrt-cu12 == 10.12.0.36",
"colored",
"opencv-python-headless==4.8.0.74",
"onnx >=1.18.0",
"onnxruntime ~= 1.22.0",
"onnxruntime-gpu ~= 1.22.0",
"onnx-graphsurgeon",
"polygraphy >= 0.49.22",
]
all = [
"flux[gradio]",
"flux[streamlit]",
"flux[torch]",
]
[project.scripts]
flux = "flux.cli:app"
[build-system]
build-backend = "setuptools.build_meta"
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8"]
[tool.ruff]
line-length = 110
target-version = "py310"
extend-exclude = ["/usr/lib/*"]
[tool.ruff.lint]
ignore = [
"E501", # line too long - will be fixed in format
]
[tool.ruff.format]
quote-style = "double"
indent-style = "space"
line-ending = "auto"
skip-magic-trailing-comma = false
docstring-code-format = true
exclude = [
"src/flux/_version.py", # generated by setuptools_scm
]
[tool.ruff.lint.isort]
combine-as-imports = true
force-wrap-aliases = true
known-local-folder = ["src"]
known-first-party = ["flux"]
[tool.pyright]
include = ["src"]
exclude = [
"**/__pycache__", # cache directories
"./typings", # generated type stubs
]
stubPath = "./typings"
[tool.tomlsort]
in_place = true
no_sort_tables = true
spaces_before_inline_comment = 1
spaces_indent_inline_array = 2
trailing_comma_inline_array = true
sort_first = [
"project",
"build-system",
"tool.setuptools",
]
# needs to be last for CI reasons
[tool.setuptools_scm]
write_to = "src/flux/_version.py"
parentdir_prefix_version = "flux-"
fallback_version = "0.0.0"
version_scheme = "post-release"
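The `all` extra is defined in terms of the package's other extras (`flux[gradio]`, `flux[streamlit]`, `flux[torch]`), so installing `flux[all]` does not pull in the `tensorrt` extra. As a minimal sketch, the hypothetical helper below (not part of the package) resolves those self-references, using a dict copied by hand from `[project.optional-dependencies]` above rather than parsing the TOML file.

```python
import re

# Copied by hand from [project.optional-dependencies] in pyproject.toml.
OPTIONAL_DEPENDENCIES: dict[str, list[str]] = {
    "torch": ["torch == 2.6.0", "torchvision"],
    "streamlit": ["streamlit", "streamlit-drawable-canvas", "streamlit-keyup"],
    "gradio": ["gradio"],
    "tensorrt": [
        "tensorrt-cu12 == 10.12.0.36",
        "colored",
        "opencv-python-headless==4.8.0.74",
        "onnx >=1.18.0",
        "onnxruntime ~= 1.22.0",
        "onnxruntime-gpu ~= 1.22.0",
        "onnx-graphsurgeon",
        "polygraphy >= 0.49.22",
    ],
    "all": ["flux[gradio]", "flux[streamlit]", "flux[torch]"],
}

# Matches a self-reference such as "flux[gradio]" and captures the extra's name.
SELF_REF = re.compile(r"^flux\[(?P<extra>[\w-]+)\]$")


def expand_extra(extra: str, table: dict[str, list[str]] = OPTIONAL_DEPENDENCIES) -> list[str]:
    """Return the concrete requirements of `extra`, following flux[...] self-references."""
    resolved: list[str] = []
    for req in table[extra]:
        match = SELF_REF.match(req)
        if match:
            # Recurse into the referenced extra (no cycle detection in this sketch).
            resolved.extend(expand_extra(match["extra"], table))
        else:
            resolved.append(req)
    return resolved
```

For example, `expand_extra("all")` lists the gradio, streamlit and torch requirements but nothing from `tensorrt`; a real installer performs this resolution from the package metadata.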
================================================
FILE: setup.py
================================================
import setuptools
setuptools.setup()
================================================
FILE: src/flux/__init__.py
================================================
try:
    from ._version import (
        version as __version__,  # type: ignore
        version_tuple,
    )
except ImportError:
    __version__ = "unknown (no version information available)"
    version_tuple = (0, 0, "unknown", "noinfo")
from pathlib import Path
PACKAGE = __package__.replace("_", "-")
PACKAGE_ROOT = Path(__file__).parent
================================================
FILE: src/flux/__main__.py
================================================
from fire import Fire
from .cli import main as cli_main
from .cli_control import main as control_main
from .cli_fill import main as fill_main
from .cli_kontext import main as kontext_main
from .cli_redux import main as redux_main
if __name__ == "__main__":
    Fire(
        {
            "t2i": cli_main,
            "control": control_main,
            "fill": fill_main,
            "kontext": kontext_main,
            "redux": redux_main,
        }
    )
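With `python -m flux <command> ...`, Fire receives the dictionary above, uses the first argument to select an entry point (`t2i`, `control`, `fill`, `kontext` or `redux`) and maps the remaining flags onto that function's keyword arguments. The stdlib-only sketch below imitates that dispatch for illustration; `dispatch` and `parse_value` are hypothetical names, and this is a simplified stand-in, not Fire's implementation. It supports only `--key=value` and bare `--flag` arguments.

```python
import ast
from typing import Any, Callable


def parse_value(raw: str) -> Any:
    """Interpret a flag value as a Python literal (int, float, bool, None, ...), else keep the string."""
    try:
        return ast.literal_eval(raw)
    except (ValueError, SyntaxError, TypeError):
        return raw


def dispatch(argv: list[str], commands: dict[str, Callable[..., Any]]) -> Any:
    """Call commands[argv[0]] with keyword arguments parsed from '--key=value' and '--flag' arguments."""
    if not argv or argv[0] not in commands:
        raise SystemExit(f"usage: COMMAND [--key=value ...]; COMMAND is one of {sorted(commands)}")
    kwargs: dict[str, Any] = {}
    for arg in argv[1:]:
        if not arg.startswith("--"):
            raise SystemExit(f"unsupported argument: {arg!r}")
        key, sep, value = arg[2:].partition("=")
        # Normalise hyphens so that --num-steps maps onto a num_steps parameter.
        name = key.replace("-", "_")
        # A bare flag such as --loop is treated as True.
        kwargs[name] = parse_value(value) if sep else True
    return commands[argv[0]](**kwargs)
```

For example, `dispatch(["t2i", "--width=512", "--loop"], {"t2i": cli_main})` would call `cli_main(width=512, loop=True)`. Fire itself does considerably more (positional arguments, space-separated `--flag value` pairs, generated help), which is why the package relies on it.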
================================================
FILE: src/flux/cli.py
================================================
import os
import re
import time
from dataclasses import dataclass
from glob import iglob
import torch
from fire import Fire
from transformers import pipeline
from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
from flux.util import (
    check_onnx_access_for_trt,
    configs,
    load_ae,
    load_clip,
    load_flow_model,
    load_t5,
    save_image,
)
NSFW_THRESHOLD = 0.85
@dataclass
class SamplingOptions:
    prompt: str
    width: int
    height: int
    num_steps: int
    guidance: float
    seed: int | None
def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
    user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
    usage = (
        "Usage: Either write your prompt directly, leave this field empty "
        "to repeat the prompt or write a command starting with a slash:\n"
        "- '/w <width>' will set the width of the generated image\n"
        "- '/h <height>' will set the height of the generated image\n"
        "- '/s <seed>' sets the next seed\n"
        "- '/g <guidance>' sets the guidance (flux-dev only)\n"
        "- '/n <steps>' sets the number of steps\n"
        "- '/q' to quit"
    )
    while (prompt := input(user_question)).startswith("/"):
        if prompt.startswith("/w"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, width = prompt.split()
            options.width = 16 * (int(width) // 16)
            print(
                f"Setting resolution to {options.width} x {options.height} "
                f"({options.height * options.width / 1e6:.2f}MP)"
            )
        elif prompt.startswith("/h"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, height = prompt.split()
            options.height = 16 * (int(height) // 16)
            print(
                f"Setting resolution to {options.width} x {options.height} "
                f"({options.height * options.width / 1e6:.2f}MP)"
            )
        elif prompt.startswith("/g"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, guidance = prompt.split()
            options.guidance = float(guidance)
            print(f"Setting guidance to {options.guidance}")
        elif prompt.startswith("/s"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, seed = prompt.split()
            options.seed = int(seed)
            print(f"Setting seed to {options.seed}")
        elif prompt.startswith("/n"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, steps = prompt.split()
            options.num_steps = int(steps)
            print(f"Setting number of steps to {options.num_steps}")
        elif prompt.startswith("/q"):
            print("Quitting")
            return None
        else:
            if not prompt.startswith("/h"):
                print(f"Got invalid command '{prompt}'\n{usage}")
            print(usage)
    if prompt != "":
        options.prompt = prompt
    return options
@torch.inference_mode()
def main(
    name: str = "flux-dev-krea",
    width: int = 1360,
    height: int = 768,
    seed: int | None = None,
    prompt: str = (
        "a photo of a forest with mist swirling around the tree trunks. The word "
        '"FLUX" is painted over it in big, red brush strokes with visible texture'
    ),
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    num_steps: int | None = None,
    loop: bool = False,
    guidance: float = 2.5,
    offload: bool = False,
    output_dir: str = "output",
    add_sampling_metadata: bool = True,
    trt: bool = False,
    trt_transformer_precision: str = "bf16",
    track_usage: bool = False,
):
    """
    Sample the flux model. Either interactively (set `--loop`) or run for a
    single image.
    Args:
        name: Name of the model to load
        height: height of the sample in pixels (should be a multiple of 16)
        width: width of the sample in pixels (should be a multiple of 16)
        seed: Set a seed for sampling
        output_dir: directory where output images are saved as `img_{idx}.jpg`,
            where `{idx}` is replaced by the index of the sample
        prompt: Prompt used for sampling; separate several prompts with `|`
        device: Pytorch device
        num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
        loop: start an interactive session and sample multiple times
        guidance: guidance value used for guidance distillation
        offload: keep the models on the CPU and move each one to `device` only while it is needed
        add_sampling_metadata: Add the prompt to the image Exif metadata
        trt: use TensorRT backend for optimized inference
        trt_transformer_precision: specify transformer precision for inference
        track_usage: track usage of the model for licensing purposes
    """
    prompt = prompt.split("|")
    if len(prompt) == 1:
        prompt = prompt[0]
        additional_prompts = None
    else:
        additional_prompts = prompt[1:]
        prompt = prompt[0]
    assert not (
        (additional_prompts is not None) and loop
    ), "Additional prompts (separated by '|') cannot be combined with loop=True"
    nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
    if name not in configs:
        available = ", ".join(configs.keys())
        raise ValueError(f"Got unknown model name: {name}, choose from {available}")
    torch_device = torch.device(device)
    if num_steps is None:
        num_steps = 4 if name == "flux-schnell" else 50
    # allow for packing and conversion to latent space
    height = 16 * (height // 16)
    width = 16 * (width // 16)
    output_name = os.path.join(output_dir, "img_{idx}.jpg")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        idx = 0
    else:
        fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
        if len(fns) > 0:
            idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
        else:
            idx = 0
    if not trt:
        t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512)
        clip = load_clip(torch_device)
        model = load_flow_model(name, device="cpu" if offload else torch_device)
        ae = load_ae(name, device="cpu" if offload else torch_device)
    else:
        # lazy import to make install optional
        from flux.trt.trt_manager import ModuleName, TRTManager

        # Check if we need ONNX model access (which requires authentication for FLUX models)
        onnx_dir = check_onnx_access_for_trt(name, trt_transformer_precision)
        trt_ctx_manager = TRTManager(
            trt_transformer_precision=trt_transformer_precision,
            trt_t5_precision=os.getenv("TRT_T5_PRECISION", "bf16"),
        )
        engines = trt_ctx_manager.load_engines(
            model_name=name,
            module_names={
                ModuleName.CLIP,
                ModuleName.TRANSFORMER,
                ModuleName.T5,
                ModuleName.VAE,
            },
            engine_dir=os.environ.get("TRT_ENGINE_DIR", "./engines"),
            custom_onnx_paths=onnx_dir or os.environ.get("CUSTOM_ONNX_PATHS", ""),
            trt_image_height=height,
            trt_image_width=width,
            trt_batch_size=1,
            trt_timing_cache=os.getenv("TRT_TIMING_CACHE_FILE", None),
            trt_static_batch=False,
            trt_static_shape=False,
        )
        ae = engines[ModuleName.VAE].to(device="cpu" if offload else torch_device)
        model = engines[ModuleName.TRANSFORMER].to(device="cpu" if offload else torch_device)
        clip = engines[ModuleName.CLIP].to(torch_device)
        t5 = engines[ModuleName.T5].to(device="cpu" if offload else torch_device)
    rng = torch.Generator(device="cpu")
    opts = SamplingOptions(
        prompt=prompt,
        width=width,
        height=height,
        num_steps=num_steps,
        guidance=guidance,
        seed=seed,
    )
    if loop:
        opts = parse_prompt(opts)
    while opts is not None:
        if opts.seed is None:
            opts.seed = rng.seed()
        print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
        t0 = time.perf_counter()
        # prepare input
        x = get_noise(
            1,
            opts.height,
            opts.width,
            device=torch_device,
            dtype=torch.bfloat16,
            seed=opts.seed,
        )
        opts.seed = None
        if offload:
            ae = ae.cpu()
            torch.cuda.empty_cache()
            t5, clip = t5.to(torch_device), clip.to(torch_device)
        inp = prepare(t5, clip, x, prompt=opts.prompt)
        timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
        # offload text encoders to CPU, load model to GPU
        if offload:
            t5, clip = t5.cpu(), clip.cpu()
            torch.cuda.empty_cache()
            model = model.to(torch_device)
        # denoise initial noise
        x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
        # offload model, load autoencoder to GPU
        if offload:
            model.cpu()
            torch.cuda.empty_cache()
            ae.decoder.to(x.device)
        # decode latents to pixel space
        x = unpack(x.float(), opts.height, opts.width)
        with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
            x = ae.decode(x)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t1 = time.perf_counter()
        fn = output_name.format(idx=idx)
        print(f"Done in {t1 - t0:.1f}s. Saving {fn}")
        # record the prompt actually used for this sample, not only the first one
        idx = save_image(
            nsfw_classifier,
            name,
            output_name,
            idx,
            x,
            add_sampling_metadata,
            opts.prompt,
            track_usage=track_usage,
        )
        if loop:
            print("-" * 80)
            opts = parse_prompt(opts)
        elif additional_prompts:
            next_prompt = additional_prompts.pop(0)
            opts.prompt = next_prompt
        else:
            opts = None
    if trt:
        trt_ctx_manager.stop_runtime()
if __name__ == "__main__":
    Fire(main)
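`main` floors the requested width and height to multiples of 16 and picks the next free `img_{idx}.jpg` index in `output_dir`. The sketch below restates both steps as pure functions that can be checked without a GPU; the function names are hypothetical. `packed_seq_len` additionally assumes FLUX's 8x VAE downsampling followed by 2x2 latent patch packing (implemented in `flux.sampling`, not shown here), under which the image token count passed to `get_schedule` as `inp["img"].shape[1]` would be `(height // 16) * (width // 16)` for dimensions that are already multiples of 16.

```python
import os
import re

# Same filename pattern cli.py uses to recognise its own outputs.
OUTPUT_RE = re.compile(r"img_([0-9]+)\.jpg$")


def round_to_multiple_of_16(size: int) -> int:
    """Floor a pixel dimension to a multiple of 16, as main() does for width and height."""
    return 16 * (size // 16)


def packed_seq_len(height: int, width: int) -> int:
    """Image token count, assuming 8x VAE downsampling and 2x2 patch packing."""
    return (height // 16) * (width // 16)


def next_output_index(filenames: list[str]) -> int:
    """Return one past the largest index among img_<n>.jpg files, or 0 if there are none."""
    indices = [int(m.group(1)) for fn in filenames if (m := OUTPUT_RE.search(os.path.basename(fn)))]
    return max(indices) + 1 if indices else 0
```

For the default 1360 x 768 resolution this gives 85 * 48 = 4080 image tokens, and a directory holding `img_0.jpg` and `img_7.jpg` continues at index 8.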
================================================
FILE: src/flux/cli_control.py
================================================
import os
import re
import time
from dataclasses import dataclass
from glob import iglob
import torch
from fire import Fire
from transformers import pipeline
from flux.modules.image_embedders import CannyImageEncoder, DepthImageEncoder
from flux.sampling import denoise, get_noise, get_schedule, prepare_control, unpack
from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image
@dataclass
class SamplingOptions:
    prompt: str
    width: int
    height: int
    num_steps: int
    guidance: float
    seed: int | None
    img_cond_path: str
    lora_scale: float | None
def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
    user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
    usage = (
        "Usage: Either write your prompt directly, leave this field empty "
        "to repeat the prompt or write a command starting with a slash:\n"
        "- '/w <width>' will set the width of the generated image\n"
        "- '/h <height>' will set the height of the generated image\n"
        "- '/s <seed>' sets the next seed\n"
        "- '/g <guidance>' sets the guidance (flux-dev only)\n"
        "- '/n <steps>' sets the number of steps\n"
        "- '/q' to quit"
    )
    while (prompt := input(user_question)).startswith("/"):
        if prompt.startswith("/w"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, width = prompt.split()
            options.width = 16 * (int(width) // 16)
            print(
                f"Setting resolution to {options.width} x {options.height} "
                f"({options.height * options.width / 1e6:.2f}MP)"
            )
        elif prompt.startswith("/h"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, height = prompt.split()
            options.height = 16 * (int(height) // 16)
            print(
                f"Setting resolution to {options.width} x {options.height} "
                f"({options.height * options.width / 1e6:.2f}MP)"
            )
        elif prompt.startswith("/g"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, guidance = prompt.split()
            options.guidance = float(guidance)
            print(f"Setting guidance to {options.guidance}")
        elif prompt.startswith("/s"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, seed = prompt.split()
            options.seed = int(seed)
            print(f"Setting seed to {options.seed}")
        elif prompt.startswith("/n"):
            if prompt.count(" ") != 1:
                print(f"Got invalid command '{prompt}'\n{usage}")
                continue
            _, steps = prompt.split()
            options.num_steps = int(steps)
            print(f"Setting number of steps to {options.num_steps}")
        elif prompt.startswith("/q"):
            print("Quitting")
            return None
        else:
            if not prompt.startswith("/h"):
                print(f"Got invalid command '{prompt}'\n{usage}")
            print(usage)
    if prompt != "":
        options.prompt = prompt
    return options
def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None:
    if options is None:
        return None
    user_question = "Next conditioning image (write /h for help, /q to quit and leave empty to repeat):\n"
    usage = (
        "Usage: Either write the path to your conditioning image directly, leave this field empty "
        "to repeat the conditioning image or write a command starting with a slash:\n"
        "- '/q' to quit"
    )
    while True:
        img_cond_path = input(user_question)
        if img_cond_path.startswith("/"):
            if img_cond_path.startswith("/q"):
                print("Quitting")
                return None
            else:
                if not img_cond_path.startswith("/h"):
                    print(f"Got invalid command '{img_cond_path}'\n{usage}")
                print(usage)
            continue
        if img_cond_path == "":
            break
        if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith(
            (".jpg", ".jpeg", ".png", ".webp")
        ):
            print(f"File '{img_cond_path}' does not exist or is not a valid image file")
            continue
        options.img_cond_path = img_cond_path
        break
    return options
def parse_lora_scale(options: SamplingOptions | None) -> tuple[SamplingOptions | None, bool]:
    changed = False
    if options is None:
        return None, changed
    user_question = "Next lora scale (write /h for help, /q to quit and leave empty to repeat):\n"
    usage = (
        "Usage: Either write the lora scale directly, leave this field empty "
        "to repeat the lora scale or write a command starting with a slash:\n"
        "- '/q' to quit"
    )
    while (prompt := input(user_question)).startswith("/"):
        if prompt.startswith("/q"):
            print("Quitting")
            return None, changed
        else:
            if not prompt.startswith("/h"):
                print(f"Got invalid command '{prompt}'\n{usage}")
            print(usage)
    if prompt != "":
        options.lora_scale = float(prompt)
        changed = True
    return options, changed
@torch.inference_mode()
def main(
name: str,
width: int = 1024,
height: int = 1024,
seed: int | None = None,
prompt: str = "a robot made out of gold",
device: str = "cuda" if torch.cuda.is_available() else "cpu",
num_steps: int = 50,
loop: bool = False,
guidance: float | None = None,
offload: bool = False,
output_dir: str = "output",
add_sampling_metadata: bool = True,
img_cond_path: str = "assets/robot.webp",
lora_scale: float | None = 0.85,
trt: bool = False,
trt_transformer_precision: str = "bf16",
track_usage: bool = False,
**kwargs: dict | None,
):
"""
Sample the flux model. Either interactively (set `--loop`) or run for a
single image.
Args:
name: Name of the model to use (flux-dev-canny, flux-dev-depth,
flux-dev-canny-lora or flux-dev-depth-lora)
height: height of the sample in pixels (should be a multiple of 16)
width: width of the sample in pixels (should be a multiple of 16)
seed: Set a seed for sampling
prompt: Prompt used for sampling
device: PyTorch device
num_steps: number of sampling steps
loop: start an interactive session and sample multiple times
guidance: guidance value used for guidance distillation (None defaults
to 30.0 for the canny models and 10.0 for the depth models)
offload: offload models to CPU when not in use
output_dir: where to save the output images, named `img_{idx}.jpg`
add_sampling_metadata: Add the prompt to the image Exif metadata
img_cond_path: path to conditioning image (jpeg/png/webp)
lora_scale: scale applied to the LoRA weights (LoRA models only)
trt: use TensorRT backend for optimized inference
trt_transformer_precision: specify transformer precision for inference
track_usage: track usage of the model for licensing purposes
"""
nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
if "lora" in name:
assert not trt, "TRT does not support LORA"
assert name in [
"flux-dev-canny",
"flux-dev-depth",
"flux-dev-canny-lora",
"flux-dev-depth-lora",
], f"Got unknown model name: {name}"
if guidance is None:
if name in ["flux-dev-canny", "flux-dev-canny-lora"]:
guidance = 30.0
elif name in ["flux-dev-depth", "flux-dev-depth-lora"]:
guidance = 10.0
else:
raise NotImplementedError()
if name not in configs:
available = ", ".join(configs.keys())
raise ValueError(f"Got unknown model name: {name}, choose from {available}")
torch_device = torch.device(device)
output_name = os.path.join(output_dir, "img_{idx}.jpg")
if not os.path.exists(output_dir):
os.makedirs(output_dir)
idx = 0
else:
fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
if len(fns) > 0:
idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
else:
idx = 0
if name in ["flux-dev-depth", "flux-dev-depth-lora"]:
img_embedder = DepthImageEncoder(torch_device)
elif name in ["flux-dev-canny", "flux-dev-canny-lora"]:
img_embedder = CannyImageEncoder(torch_device)
else:
raise NotImplementedError()
if not trt:
# init all components
t5 = load_t5(torch_device, max_length=512)
clip = load_clip(torch_device)
model = load_flow_model(name, device="cpu" if offload else torch_device)
ae = load_ae(name, device="cpu" if offload else torch_device)
else:
# lazy import to make install optional
from flux.trt.trt_manager import ModuleName, TRTManager
trt_ctx_manager = TRTManager(
trt_transformer_precision=trt_transformer_precision,
trt_t5_precision=os.environ.get("TRT_T5_PRECISION", "bf16"),
)
engines = trt_ctx_manager.load_engines(
model_name=name,
module_names={
ModuleName.CLIP,
ModuleName.TRANSFORMER,
ModuleName.T5,
ModuleName.VAE,
ModuleName.VAE_ENCODER,
},
engine_dir=os.environ.get("TRT_ENGINE_DIR", "./engines"),
custom_onnx_paths=os.environ.get("CUSTOM_ONNX_PATHS", ""),
trt_image_height=height,
trt_image_width=width,
trt_batch_size=1,
trt_static_batch=kwargs.get("static_batch", True),
trt_static_shape=kwargs.get("static_shape", True),
)
ae = engines[ModuleName.VAE].to(device="cpu" if offload else torch_device)
model = engines[ModuleName.TRANSFORMER].to(device="cpu" if offload else torch_device)
clip = engines[ModuleName.CLIP].to(torch_device)
t5 = engines[ModuleName.T5].to(device="cpu" if offload else torch_device)
# set lora scale
if "lora" in name and lora_scale is not None:
for _, module in model.named_modules():
if hasattr(module, "set_scale"):
module.set_scale(lora_scale)
rng = torch.Generator(device="cpu")
opts = SamplingOptions(
prompt=prompt,
width=width,
height=height,
num_steps=num_steps,
guidance=guidance,
seed=seed,
img_cond_path=img_cond_path,
lora_scale=lora_scale,
)
if loop:
opts = parse_prompt(opts)
opts = parse_img_cond_path(opts)
if "lora" in name:
opts, changed = parse_lora_scale(opts)
if changed:
# update the lora scale:
for _, module in model.named_modules():
if hasattr(module, "set_scale"):
module.set_scale(opts.lora_scale)
while opts is not None:
if opts.seed is None:
opts.seed = rng.seed()
print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
t0 = time.perf_counter()
# prepare input
x = get_noise(
1,
opts.height,
opts.width,
device=torch_device,
dtype=torch.bfloat16,
seed=opts.seed,
)
opts.seed = None
if offload:
t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device)
inp = prepare_control(
t5,
clip,
x,
prompt=opts.prompt,
ae=ae,
encoder=img_embedder,
img_cond_path=opts.img_cond_path,
)
timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
# offload TEs and AE to CPU, load model to gpu
if offload:
t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu()
torch.cuda.empty_cache()
model = model.to(torch_device)
# denoise initial noise
x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
# offload model, load autoencoder to gpu
if offload:
model.cpu()
torch.cuda.empty_cache()
ae.decoder.to(x.device)
# decode latents to pixel space
x = unpack(x.float(), opts.height, opts.width)
with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
x = ae.decode(x)
if torch.cuda.is_available():
torch.cuda.synchronize()
t1 = time.perf_counter()
print(f"Done in {t1 - t0:.1f}s")
idx = save_image(
nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, opts.prompt, track_usage=track_usage
)
if loop:
print("-" * 80)
opts = parse_prompt(opts)
opts = parse_img_cond_path(opts)
if "lora" in name:
opts, changed = parse_lora_scale(opts)
if changed:
# update the lora scale:
for _, module in model.named_modules():
if hasattr(module, "set_scale"):
module.set_scale(opts.lora_scale)
else:
opts = None
if trt:
trt_ctx_manager.stop_runtime()
if __name__ == "__main__":
Fire(main)
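Every CLI in this section chooses the next output filename the same way: create `output_dir` if it is missing, otherwise glob for `img_*.jpg`, keep only names matching `img_<digits>.jpg`, and continue from the largest index plus one. A minimal standalone sketch of that logic follows; `next_output_index` is a hypothetical helper, not a function in the repository, and it reads the number from a regex group instead of the `split("_")` chain, which gives the same result for these filenames.

```python
import os
import re
from glob import iglob


def next_output_index(output_dir: str) -> int:
    """Return the next free index for files named img_{idx}.jpg in output_dir.

    Hypothetical helper mirroring the CLI scripts: create the directory if it
    does not exist (and start at 0), otherwise return max(existing index) + 1.
    """
    output_name = os.path.join(output_dir, "img_{idx}.jpg")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        return 0
    indices = []
    for fn in iglob(output_name.format(idx="*")):
        # Skip names like img_x.jpg that the glob matches but are not numbered.
        match = re.search(r"img_([0-9]+)\.jpg$", fn)
        if match:
            indices.append(int(match.group(1)))
    return max(indices) + 1 if indices else 0
```

Gaps are not reused: with `img_0.jpg` and `img_3.jpg` present, the next image is `img_4.jpg`, so earlier outputs are never overwritten.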
================================================
FILE: src/flux/cli_fill.py
================================================
import os
import re
import time
from dataclasses import dataclass
from glob import iglob
import torch
from fire import Fire
from PIL import Image
from transformers import pipeline
from flux.sampling import denoise, get_noise, get_schedule, prepare_fill, unpack
from flux.util import configs, load_ae, load_clip, load_flow_model, load_t5, save_image
@dataclass
class SamplingOptions:
prompt: str
width: int
height: int
num_steps: int
guidance: float
seed: int | None
img_cond_path: str
img_mask_path: str
def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
usage = (
"Usage: Either write your prompt directly, leave this field empty "
"to repeat the prompt or write a command starting with a slash:\n"
"- '/s <seed>' sets the next seed\n"
"- '/g <guidance>' sets the guidance (flux-dev only)\n"
"- '/n <steps>' sets the number of steps\n"
"- '/q' to quit"
)
while (prompt := input(user_question)).startswith("/"):
if prompt.startswith("/g"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, guidance = prompt.split()
options.guidance = float(guidance)
print(f"Setting guidance to {options.guidance}")
elif prompt.startswith("/s"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, seed = prompt.split()
options.seed = int(seed)
print(f"Setting seed to {options.seed}")
elif prompt.startswith("/n"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, steps = prompt.split()
options.num_steps = int(steps)
print(f"Setting number of steps to {options.num_steps}")
elif prompt.startswith("/q"):
print("Quitting")
return None
else:
if not prompt.startswith("/h"):
print(f"Got invalid command '{prompt}'\n{usage}")
print(usage)
if prompt != "":
options.prompt = prompt
return options
def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None:
if options is None:
return None
user_question = "Next conditioning image (write /h for help, /q to quit and leave empty to repeat):\n"
usage = (
"Usage: Either write a path to an image directly, leave this field empty "
"to repeat the conditioning image or write a command starting with a slash:\n"
"- '/q' to quit"
)
while True:
img_cond_path = input(user_question)
if img_cond_path.startswith("/"):
if img_cond_path.startswith("/q"):
print("Quitting")
return None
else:
if not img_cond_path.startswith("/h"):
print(f"Got invalid command '{img_cond_path}'\n{usage}")
print(usage)
continue
if img_cond_path == "":
break
if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith(
(".jpg", ".jpeg", ".png", ".webp")
):
print(f"File '{img_cond_path}' does not exist or is not a valid image file")
continue
else:
with Image.open(img_cond_path) as img:
width, height = img.size
if width % 32 != 0 or height % 32 != 0:
print(f"Image dimensions must be divisible by 32, got {width}x{height}")
continue
options.img_cond_path = img_cond_path
break
return options
def parse_img_mask_path(options: SamplingOptions | None) -> SamplingOptions | None:
if options is None:
return None
user_question = "Next conditioning mask (write /h for help, /q to quit and leave empty to repeat):\n"
usage = (
"Usage: Either write a path to a mask image directly, leave this field empty "
"to repeat the conditioning mask or write a command starting with a slash:\n"
"- '/q' to quit"
)
while True:
img_mask_path = input(user_question)
if img_mask_path.startswith("/"):
if img_mask_path.startswith("/q"):
print("Quitting")
return None
else:
if not img_mask_path.startswith("/h"):
print(f"Got invalid command '{img_mask_path}'\n{usage}")
print(usage)
continue
if img_mask_path == "":
break
if not os.path.isfile(img_mask_path) or not img_mask_path.lower().endswith(
(".jpg", ".jpeg", ".png", ".webp")
):
print(f"File '{img_mask_path}' does not exist or is not a valid image file")
continue
else:
with Image.open(img_mask_path) as img:
width, height = img.size
if width % 32 != 0 or height % 32 != 0:
print(f"Image dimensions must be divisible by 32, got {width}x{height}")
continue
else:
with Image.open(options.img_cond_path) as img_cond:
img_cond_width, img_cond_height = img_cond.size
if width != img_cond_width or height != img_cond_height:
print(
f"Mask dimensions must match conditioning image, got {width}x{height} and {img_cond_width}x{img_cond_height}"
)
continue
options.img_mask_path = img_mask_path
break
return options
@torch.inference_mode()
def main(
seed: int | None = None,
prompt: str = "a white paper cup",
device: str = "cuda" if torch.cuda.is_available() else "cpu",
num_steps: int = 50,
loop: bool = False,
guidance: float = 30.0,
offload: bool = False,
output_dir: str = "output",
add_sampling_metadata: bool = True,
img_cond_path: str = "assets/cup.png",
img_mask_path: str = "assets/cup_mask.png",
track_usage: bool = False,
):
"""
Sample the flux model. Either interactively (set `--loop`) or run for a
single image. This demo assumes that the conditioning image and mask have
the same shape and that height and width are divisible by 32. The output
resolution is taken from the conditioning image.
Args:
seed: Set a seed for sampling
prompt: Prompt used for sampling
device: PyTorch device
num_steps: number of sampling steps
loop: start an interactive session and sample multiple times
guidance: guidance value used for guidance distillation
offload: offload models to CPU when not in use
output_dir: where to save the output images, named `img_{idx}.jpg`
add_sampling_metadata: Add the prompt to the image Exif metadata
img_cond_path: path to conditioning image (jpeg/png/webp)
img_mask_path: path to conditioning mask (jpeg/png/webp)
track_usage: track usage of the model for licensing purposes
"""
nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
name = "flux-dev-fill"
if name not in configs:
available = ", ".join(configs.keys())
raise ValueError(f"Got unknown model name: {name}, choose from {available}")
torch_device = torch.device(device)
output_name = os.path.join(output_dir, "img_{idx}.jpg")
if not os.path.exists(output_dir):
os.makedirs(output_dir)
idx = 0
else:
fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
if len(fns) > 0:
idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
else:
idx = 0
# init all components
t5 = load_t5(torch_device, max_length=128)
clip = load_clip(torch_device)
model = load_flow_model(name, device="cpu" if offload else torch_device)
ae = load_ae(name, device="cpu" if offload else torch_device)
rng = torch.Generator(device="cpu")
with Image.open(img_cond_path) as img:
width, height = img.size
opts = SamplingOptions(
prompt=prompt,
width=width,
height=height,
num_steps=num_steps,
guidance=guidance,
seed=seed,
img_cond_path=img_cond_path,
img_mask_path=img_mask_path,
)
if loop:
opts = parse_prompt(opts)
opts = parse_img_cond_path(opts)
if opts is not None:
with Image.open(opts.img_cond_path) as img:
width, height = img.size
opts.height = height
opts.width = width
opts = parse_img_mask_path(opts)
while opts is not None:
if opts.seed is None:
opts.seed = rng.seed()
print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
t0 = time.perf_counter()
# prepare input
x = get_noise(
1,
opts.height,
opts.width,
device=torch_device,
dtype=torch.bfloat16,
seed=opts.seed,
)
opts.seed = None
if offload:
t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device)
inp = prepare_fill(
t5,
clip,
x,
prompt=opts.prompt,
ae=ae,
img_cond_path=opts.img_cond_path,
mask_path=opts.img_mask_path,
)
timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
# offload TEs and AE to CPU, load model to gpu
if offload:
t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu()
torch.cuda.empty_cache()
model = model.to(torch_device)
# denoise initial noise
x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
# offload model, load autoencoder to gpu
if offload:
model.cpu()
torch.cuda.empty_cache()
ae.decoder.to(x.device)
# decode latents to pixel space
x = unpack(x.float(), opts.height, opts.width)
with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
x = ae.decode(x)
if torch.cuda.is_available():
torch.cuda.synchronize()
t1 = time.perf_counter()
print(f"Done in {t1 - t0:.1f}s")
idx = save_image(
nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, opts.prompt, track_usage=track_usage
)
if loop:
print("-" * 80)
opts = parse_prompt(opts)
opts = parse_img_cond_path(opts)
if opts is not None:
with Image.open(opts.img_cond_path) as img:
width, height = img.size
opts.height = height
opts.width = width
opts = parse_img_mask_path(opts)
else:
opts = None
if __name__ == "__main__":
Fire(main)
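`parse_img_cond_path` and `parse_img_mask_path` in `cli_fill.py` enforce two constraints before sampling: both images must have sides divisible by 32, and the mask must have exactly the same size as the conditioning image. The sketch below restates those checks as a pure function over `(width, height)` tuples so they can be tested without opening files; `check_fill_sizes` is a hypothetical name, and in the script the sizes come from PIL's `Image.open(path).size`.

```python
from __future__ import annotations


def check_fill_sizes(
    img_size: tuple[int, int], mask_size: tuple[int, int], multiple: int = 32
) -> list[str]:
    """Return a list of problems; an empty list means the pair is usable.

    Hypothetical restatement of the cli_fill.py checks: each side must be a
    multiple of `multiple`, and the mask must match the conditioning image.
    """
    problems = []
    for label, (w, h) in (("image", img_size), ("mask", mask_size)):
        if w % multiple != 0 or h % multiple != 0:
            problems.append(f"{label} dimensions must be divisible by {multiple}, got {w}x{h}")
    if mask_size != img_size:
        problems.append(
            "mask dimensions must match conditioning image, got "
            f"{mask_size[0]}x{mask_size[1]} and {img_size[0]}x{img_size[1]}"
        )
    return problems
```

Collecting every problem, rather than stopping at the first, lets an interactive prompt report all issues with a pair in one message.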
================================================
FILE: src/flux/cli_kontext.py
================================================
import os
import re
import time
from dataclasses import dataclass
from glob import iglob
import torch
from fire import Fire
from flux.content_filters import PixtralContentFilter
from flux.sampling import denoise, get_schedule, prepare_kontext, unpack
from flux.util import (
aspect_ratio_to_height_width,
check_onnx_access_for_trt,
load_ae,
load_clip,
load_flow_model,
load_t5,
save_image,
)
@dataclass
class SamplingOptions:
prompt: str
width: int | None
height: int | None
num_steps: int
guidance: float
seed: int | None
img_cond_path: str
def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
user_question = "Next prompt (write /h for help, /q to quit and leave empty to repeat):\n"
usage = (
"Usage: Either write your prompt directly, leave this field empty "
"to repeat the prompt or write a command starting with a slash:\n"
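"- '/ar <width>:<height>' sets the aspect ratio of the generated image ('/ar auto' uses the input image resolution)\n"
"- '/h <height>' sets the height of the generated image ('/h auto' uses the input image height)\n"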
"- '/s <seed>' sets the next seed\n"
"- '/g <guidance>' sets the guidance (flux-dev only)\n"
"- '/n <steps>' sets the number of steps\n"
"- '/q' to quit"
)
while (prompt := input(user_question)).startswith("/"):
if prompt.startswith("/ar"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, ratio_prompt = prompt.split()
if ratio_prompt == "auto":
options.width = None
options.height = None
print("Setting resolution to input image resolution.")
else:
options.width, options.height = aspect_ratio_to_height_width(ratio_prompt)
print(f"Setting resolution to {options.width} x {options.height}.")
elif prompt.startswith("/h "):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, height = prompt.split()
if height == "auto":
options.height = None
else:
options.height = 16 * (int(height) // 16)
if options.height is not None and options.width is not None:
print(
f"Setting resolution to {options.width} x {options.height} "
f"({options.height * options.width / 1e6:.2f}MP)"
)
else:
print(f"Setting resolution to {options.width} x {options.height}.")
elif prompt.startswith("/g"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, guidance = prompt.split()
options.guidance = float(guidance)
print(f"Setting guidance to {options.guidance}")
elif prompt.startswith("/s"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, seed = prompt.split()
options.seed = int(seed)
print(f"Setting seed to {options.seed}")
elif prompt.startswith("/n"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, steps = prompt.split()
options.num_steps = int(steps)
print(f"Setting number of steps to {options.num_steps}")
elif prompt.startswith("/q"):
print("Quitting")
return None
else:
if not prompt.startswith("/h"):
print(f"Got invalid command '{prompt}'\n{usage}")
print(usage)
if prompt != "":
options.prompt = prompt
return options
def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None:
if options is None:
return None
user_question = "Next input image (write /h for help, /q to quit and leave empty to repeat):\n"
usage = (
"Usage: Either write a path to an image directly, leave this field empty "
"to repeat the last input image or write a command starting with a slash:\n"
"- '/q' to quit\n\n"
"The input image will be edited by FLUX.1 Kontext, creating a new image based "
"on your instruction prompt."
)
while True:
img_cond_path = input(user_question)
if img_cond_path.startswith("/"):
if img_cond_path.startswith("/q"):
print("Quitting")
return None
else:
if not img_cond_path.startswith("/h"):
print(f"Got invalid command '{img_cond_path}'\n{usage}")
print(usage)
continue
if img_cond_path == "":
break
if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith(
(".jpg", ".jpeg", ".png", ".webp")
):
print(f"File '{img_cond_path}' does not exist or is not a valid image file")
continue
options.img_cond_path = img_cond_path
break
return options
@torch.inference_mode()
def main(
name: str = "flux-dev-kontext",
aspect_ratio: str | None = None,
seed: int | None = None,
prompt: str = "replace the logo with the text 'Black Forest Labs'",
device: str = "cuda" if torch.cuda.is_available() else "cpu",
num_steps: int = 30,
loop: bool = False,
guidance: float = 2.5,
offload: bool = False,
output_dir: str = "output",
add_sampling_metadata: bool = True,
img_cond_path: str = "assets/cup.png",
trt: bool = False,
trt_transformer_precision: str = "bf16",
track_usage: bool = False,
):
"""
Sample the flux model. Either interactively (set `--loop`) or run for a
single image.
Args:
name: Name of the model to use (only 'flux-dev-kontext' is supported)
aspect_ratio: aspect ratio of the output as '<width>:<height>'; None
defaults to the resolution of the conditioning image
seed: Set a seed for sampling
prompt: Prompt used for sampling
device: PyTorch device
num_steps: number of sampling steps
loop: start an interactive session and sample multiple times
guidance: guidance value used for guidance distillation
offload: offload models to CPU when not in use
output_dir: where to save the output images, named `img_{idx}.jpg`
add_sampling_metadata: Add the prompt to the image Exif metadata
img_cond_path: path to conditioning image (jpeg/png/webp)
trt: use TensorRT backend for optimized inference
trt_transformer_precision: specify transformer precision for inference
track_usage: track usage of the model for licensing purposes
"""
assert name == "flux-dev-kontext", f"Got unknown model name: {name}"
torch_device = torch.device(device)
output_name = os.path.join(output_dir, "img_{idx}.jpg")
if not os.path.exists(output_dir):
os.makedirs(output_dir)
idx = 0
else:
fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
if len(fns) > 0:
idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
else:
idx = 0
if aspect_ratio is None:
width = None
height = None
else:
width, height = aspect_ratio_to_height_width(aspect_ratio)
if not trt:
t5 = load_t5(torch_device, max_length=512)
clip = load_clip(torch_device)
model = load_flow_model(name, device="cpu" if offload else torch_device)
else:
# lazy import to make install optional
from flux.trt.trt_manager import ModuleName, TRTManager
# Check if we need ONNX model access (which requires authentication for FLUX models)
onnx_dir = check_onnx_access_for_trt(name, trt_transformer_precision)
trt_ctx_manager = TRTManager(
trt_transformer_precision=trt_transformer_precision,
trt_t5_precision=os.environ.get("TRT_T5_PRECISION", "bf16"),
)
engines = trt_ctx_manager.load_engines(
model_name=name,
module_names={
ModuleName.CLIP,
ModuleName.TRANSFORMER,
ModuleName.T5,
},
engine_dir=os.environ.get("TRT_ENGINE_DIR", "./engines"),
custom_onnx_paths=onnx_dir or os.environ.get("CUSTOM_ONNX_PATHS", ""),
trt_image_height=height,
trt_image_width=width,
trt_batch_size=1,
trt_timing_cache=os.getenv("TRT_TIMING_CACHE_FILE", None),
trt_static_batch=False,
trt_static_shape=False,
)
model = engines[ModuleName.TRANSFORMER].to(device="cpu" if offload else torch_device)
clip = engines[ModuleName.CLIP].to(torch_device)
t5 = engines[ModuleName.T5].to(device="cpu" if offload else torch_device)
ae = load_ae(name, device="cpu" if offload else torch_device)
content_filter = PixtralContentFilter(torch.device("cpu"))
rng = torch.Generator(device="cpu")
opts = SamplingOptions(
prompt=prompt,
width=width,
height=height,
num_steps=num_steps,
guidance=guidance,
seed=seed,
img_cond_path=img_cond_path,
)
if loop:
opts = parse_prompt(opts)
opts = parse_img_cond_path(opts)
while opts is not None:
if opts.seed is None:
opts.seed = rng.seed()
print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
t0 = time.perf_counter()
if content_filter.test_txt(opts.prompt):
print("Your prompt has been automatically flagged. Please choose another prompt.")
if loop:
print("-" * 80)
opts = parse_prompt(opts)
opts = parse_img_cond_path(opts)
else:
opts = None
continue
if content_filter.test_image(opts.img_cond_path):
print("Your input image has been automatically flagged. Please choose another image.")
if loop:
print("-" * 80)
opts = parse_prompt(opts)
opts = parse_img_cond_path(opts)
else:
opts = None
continue
if offload:
t5, clip, ae = t5.to(torch_device), clip.to(torch_device), ae.to(torch_device)
inp, height, width = prepare_kontext(
t5=t5,
clip=clip,
prompt=opts.prompt,
ae=ae,
img_cond_path=opts.img_cond_path,
target_width=opts.width,
target_height=opts.height,
bs=1,
seed=opts.seed,
device=torch_device,
)
inp.pop("img_cond_orig")
opts.seed = None
timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))
# offload TEs and AE to CPU, load model to gpu
if offload:
t5, clip, ae = t5.cpu(), clip.cpu(), ae.cpu()
torch.cuda.empty_cache()
model = model.to(torch_device)
# denoise initial noise
t00 = time.perf_counter()
x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)
if torch.cuda.is_available():
torch.cuda.synchronize()
t01 = time.perf_counter()
print(f"Denoising took {t01 - t00:.3f}s")
# offload model, load autoencoder to gpu
if offload:
model.cpu()
torch.cuda.empty_cache()
ae.decoder.to(x.device)
# decode latents to pixel space
x = unpack(x.float(), height, width)
with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
ae_dev_t0 = time.perf_counter()
x = ae.decode(x)
if torch.cuda.is_available():
torch.cuda.synchronize()
ae_dev_t1 = time.perf_counter()
print(f"AE decode took {ae_dev_t1 - ae_dev_t0:.3f}s")
if content_filter.test_image(x.cpu()):
print(
"Your output image has been automatically flagged. Choose another prompt/image or try again."
)
if loop:
print("-" * 80)
opts = parse_prompt(opts)
opts = parse_img_cond_path(opts)
else:
opts = None
continue
if torch.cuda.is_available():
torch.cuda.synchronize()
t1 = time.perf_counter()
print(f"Done in {t1 - t0:.1f}s")
idx = save_image(
None, name, output_name, idx, x, add_sampling_metadata, opts.prompt, track_usage=track_usage
)
if loop:
print("-" * 80)
opts = parse_prompt(opts)
opts = parse_img_cond_path(opts)
else:
opts = None
if __name__ == "__main__":
Fire(main)
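In `cli_kontext.py` the prompt advertises `/h` as the help command, while `parse_prompt` also handles `/h <height>` (rounded down to a multiple of 16, or `auto` to follow the input image). One way to keep the two meanings apart is to dispatch on whether an argument is present. The sketch below is a hypothetical helper (`parse_height_command` is not in the repository) that returns an action tag plus the resulting height.

```python
from __future__ import annotations


def parse_height_command(command: str, current: int | None) -> tuple[str, int | None]:
    """Interpret a '/h' command (hypothetical helper).

    Returns ("help", current) for a bare '/h', ("height", value) for
    '/h <height>' or '/h auto', and ("invalid", current) otherwise.
    Heights are rounded down to a multiple of 16, as in parse_prompt.
    """
    parts = command.split()
    if not parts or parts[0] != "/h":
        return "invalid", current
    if len(parts) == 1:
        return "help", current
    if len(parts) != 2:
        return "invalid", current
    arg = parts[1]
    if arg == "auto":
        # None means "use the input image height".
        return "height", None
    if not arg.isdigit():
        return "invalid", current
    return "height", 16 * (int(arg) // 16)
```

Splitting on whitespace instead of counting spaces also makes inputs such as `/h ` (trailing space) fall back to help instead of raising on tuple unpacking.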
================================================
FILE: src/flux/cli_redux.py
================================================
import os
import re
import time
from dataclasses import dataclass
from glob import iglob
import torch
from fire import Fire
from transformers import pipeline
from flux.modules.image_embedders import ReduxImageEncoder
from flux.sampling import denoise, get_noise, get_schedule, prepare_redux, unpack
from flux.util import (
get_checkpoint_path,
load_ae,
load_clip,
load_flow_model,
load_t5,
save_image,
)
@dataclass
class SamplingOptions:
prompt: str
width: int
height: int
num_steps: int
guidance: float
seed: int | None
img_cond_path: str
def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
user_question = "Next command (write /h for help, /q to quit and leave empty to repeat):\n"
usage = (
"Usage: Leave this field empty to do nothing "
"or write a command starting with a slash:\n"
"- '/w <width>' will set the width of the generated image\n"
"- '/h <height>' will set the height of the generated image\n"
"- '/s <seed>' sets the next seed\n"
"- '/g <guidance>' sets the guidance (flux-dev only)\n"
"- '/n <steps>' sets the number of steps\n"
"- '/q' to quit"
)
while (prompt := input(user_question)).startswith("/"):
if prompt.startswith("/w"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, width = prompt.split()
options.width = 16 * (int(width) // 16)
print(
f"Setting resolution to {options.width} x {options.height} "
f"({options.height * options.width / 1e6:.2f}MP)"
)
elif prompt.startswith("/h "):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, height = prompt.split()
options.height = 16 * (int(height) // 16)
print(
f"Setting resolution to {options.width} x {options.height} "
f"({options.height * options.width / 1e6:.2f}MP)"
)
elif prompt.startswith("/g"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, guidance = prompt.split()
options.guidance = float(guidance)
print(f"Setting guidance to {options.guidance}")
elif prompt.startswith("/s"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, seed = prompt.split()
options.seed = int(seed)
print(f"Setting seed to {options.seed}")
elif prompt.startswith("/n"):
if prompt.count(" ") != 1:
print(f"Got invalid command '{prompt}'\n{usage}")
continue
_, steps = prompt.split()
options.num_steps = int(steps)
print(f"Setting number of steps to {options.num_steps}")
elif prompt.startswith("/q"):
print("Quitting")
return None
else:
if not prompt.startswith("/h"):
print(f"Got invalid command '{prompt}'\n{usage}")
print(usage)
return options
def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOptions | None:
if options is None:
return None
user_question = "Next conditioning image (write /h for help, /q to quit and leave empty to repeat):\n"
usage = (
"Usage: Either write a path to an image directly, leave this field empty "
"to repeat the conditioning image or write a command starting with a slash:\n"
"- '/q' to quit"
)
while True:
img_cond_path = input(user_question)
if img_cond_path.startswith("/"):
if img_cond_path.startswith("/q"):
print("Quitting")
return None
else:
if not img_cond_path.startswith("/h"):
print(f"Got invalid command '{img_cond_path}'\n{usage}")
print(usage)
continue
if img_cond_path == "":
break
if not os.path.isfile(img_cond_path) or not img_cond_path.lower().endswith(
(".jpg", ".jpeg", ".png", ".webp")
):
print(f"File '{img_cond_path}' does not exist or is not a valid image file")
continue
options.img_cond_path = img_cond_path
break
return options
@torch.inference_mode()
def main(
name: str = "flux-dev",
width: int = 1360,
height: int = 768,
seed: int | None = None,
device: str = "cuda" if torch.cuda.is_available() else "cpu",
num_steps: int | None = None,
loop: bool = False,
guidance: float = 2.5,
offload: bool = False,
output_dir: str = "output",
add_sampling_metadata: bool = True,
img_cond_path: str = "assets/robot.webp",
track_usage: bool = False,
):
"""
Sample the flux model. Either interactively (set `--loop`) or run for a
single image.
Args:
name: Name of the base model to use (either 'flux-dev' or 'flux-schnell')
height: height of the sample in pixels (should be a multiple of 16)
width: width of the sample in pixels (should be a multiple of 16)
seed: Set a seed for sampling
device: PyTorch device
num_steps: number of sampling steps (default 4 for schnell, 50 for guidance distilled)
loop: start an interactive session and sample multiple times
guidance: guidance value used for guidance distillation
offload: offload models to CPU when not in use
output_dir: where to save the output images
add_sampling_metadata: Add the prompt to the image Exif metadata
img_cond_path: path to conditioning image (jpeg/png/webp)
track_usage: track usage of the model for licensing purposes
"""
nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
if name not in (available := ["flux-dev", "flux-schnell"]):
raise ValueError(f"Got unknown model name: {name}, choose from {available}")
torch_device = torch.device(device)
if num_steps is None:
num_steps = 4 if name == "flux-schnell" else 50
output_name = os.path.join(output_dir, "img_{idx}.jpg")
if not os.path.exists(output_dir):
os.makedirs(output_dir)
idx = 0
else:
fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
if len(fns) > 0:
idx = max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
else:
idx = 0
# init all components
t5 = load_t5(torch_device, max_length=256 if name == "flux-schnell" else 512)
clip = load_clip(torch_device)
model = load_flow_model(name, device="cpu" if offload else torch_device)
ae = load_ae(name, device="cpu" if offload else torch_device)
# Download and initialize the Redux adapter
redux_path = str(
get_checkpoint_path("black-forest-labs/FLUX.1-Redux-dev", "flux1-redux-dev.safetensors", "FLUX_REDUX")
)
img_embedder = ReduxImageEncoder(torch_device, redux_path=redux_path)
rng = torch.Generator(device="cpu")
prompt = ""
opts = SamplingOptions(
prompt=prompt,
width=width,
height=height,
num_steps=num_steps,
guidance=guidance,
seed=seed,
img_cond_path=img_cond_path,
)

    if loop:
        opts = parse_prompt(opts)
        opts = parse_img_cond_path(opts)

    while opts is not None:
        if opts.seed is None:
            opts.seed = rng.seed()
        print(f"Generating with seed {opts.seed}:\n{opts.prompt}")
        t0 = time.perf_counter()

        # prepare input
        x = get_noise(
            1,
            opts.height,
            opts.width,
            device=torch_device,
            dtype=torch.bfloat16,
            seed=opts.seed,
        )
        opts.seed = None
        if offload:
            ae = ae.cpu()
            torch.cuda.empty_cache()
            t5, clip = t5.to(torch_device), clip.to(torch_device)
        inp = prepare_redux(
            t5,
            clip,
            x,
            prompt=opts.prompt,
            encoder=img_embedder,
            img_cond_path=opts.img_cond_path,
        )
        timesteps = get_schedule(opts.num_steps, inp["img"].shape[1], shift=(name != "flux-schnell"))

        # offload TEs to CPU, load model to gpu
        if offload:
            t5, clip = t5.cpu(), clip.cpu()
            torch.cuda.empty_cache()
            model = model.to(torch_device)

        # denoise initial noise
        x = denoise(model, **inp, timesteps=timesteps, guidance=opts.guidance)

        # offload model, load autoencoder to gpu
        if offload:
            model.cpu()
            torch.cuda.empty_cache()
            ae.decoder.to(x.device)

        # decode latents to pixel space
        x = unpack(x.float(), opts.height, opts.width)
        with torch.autocast(device_type=torch_device.type, dtype=torch.bfloat16):
            x = ae.decode(x)

        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t1 = time.perf_counter()
        print(f"Done in {t1 - t0:.1f}s")

        # pass the prompt actually used for this sample so the Exif metadata is correct
        idx = save_image(
            nsfw_classifier, name, output_name, idx, x, add_sampling_metadata, opts.prompt, track_usage=track_usage
        )

        if loop:
            print("-" * 80)
            opts = parse_prompt(opts)
            opts = parse_img_cond_path(opts)
        else:
            opts = None


if __name__ == "__main__":
    Fire(main)
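The output-file numbering in `main` scans `output_dir` for files named `img_<n>.jpg` and continues after the highest index it finds, ignoring files that do not match the pattern. A standalone sketch of that logic using only the standard library; the helper name `next_image_index` is hypothetical and not part of the repository.

```python
import os
import re
from glob import iglob


def next_image_index(output_dir: str) -> int:
    """Return the first unused index for files named img_<n>.jpg in output_dir.

    Hypothetical helper that mirrors the numbering logic in cli_redux.main.
    """
    output_name = os.path.join(output_dir, "img_{idx}.jpg")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        return 0
    # Only count files that match the exact img_<digits>.jpg pattern.
    fns = [fn for fn in iglob(output_name.format(idx="*")) if re.search(r"img_[0-9]+\.jpg$", fn)]
    if not fns:
        return 0
    # ".../img_12.jpg" -> "12.jpg" -> "12"
    return max(int(fn.split("_")[-1].split(".")[0]) for fn in fns) + 1
```

For a directory containing `img_0.jpg`, `img_7.jpg`, `img_x.jpg` and `notes.txt`, the sketch returns 8: gaps in the numbering are not reused, and non-matching files are skipped.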
================================================
FILE: src/flux/content_filters.py
================================================
import torch
from einops import rearrange
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration, pipeline
PROMPT_IMAGE_INTEGRITY = """
Task: Analyze an image to identify potential copyright concerns or depictions of public figures.
Output: Respond with only "yes" or "no"
Criteria for "yes":
- The image contains a recognizable character from copyrighted media (movies, TV, comics, games, etc.)
- The image displays a trademarked logo or brand
- The image depicts a recognizable public figure (celebrities, politicians, athletes, influencers, historical figures, etc.)
Criteria for "no":
- All other cases
- When you cannot identify the specific copyrighted work or named individual
Critical Requirements:
1. You must be able to name the exact copyrighted work or specific person depicted
2. General references to demographics or characteristics are not sufficient
3. Base your decision solely on visual content, not interpretation
4. Provide only the one-word answer: "yes" or "no"
""".strip()
PROMPT_IMAGE_INTEGRITY_FOLLOW_UP = "Does this image have copyright concerns or includes public figures?"
PROMPT_TEXT_INTEGRITY = """
Task: Analyze a text prompt to identify potential copyright concerns or requests to depict living public figures.
Output: Respond with only "yes" or "no"
Criteria for "Yes":
- The prompt explicitly names a character from copyrighted media (movies, TV, comics, games, etc.)
- The prompt explicitly mentions a trademarked logo or brand
- The prompt names or describes a specific living public figure (celebrities, politicians, athletes, influencers, etc.)
Criteria for "No":
- All other cases
- When you cannot identify the specific copyrighted work or named individual
Critical Requirements:
1. You must be able to name the exact copyrighted work or specific person referenced
2. General demographic descriptions or characteristics are not sufficient
3. Analyze only the prompt text, not potential image outcomes
4. Provide only the one-word answer: "yes" or "no"
The prompt to check is:
-----
{prompt}
-----
Does this prompt have copyright concerns or includes public figures?
""".strip()


class PixtralContentFilter(torch.nn.Module):
    def __init__(
        self,
        device: torch.device = torch.device("cpu"),
        nsfw_threshold: float = 0.85,
    ):
        super().__init__()

        model_id = "mistral-community/pixtral-12b"
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.model = LlavaForConditionalGeneration.from_pretrained(model_id, device_map=device)

        self.yes_token, self.no_token = self.processor.tokenizer.encode(["yes", "no"])

        self.nsfw_classifier = pipeline(
            "image-classification", model="Falconsai/nsfw_image_detection", device=device
        )
        self.nsfw_threshold = nsfw_threshold

    def yes_no_logit_processor(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        """
        Restrict generation to "yes"/"no": every other token's score is set below the current minimum.
        """
        scores_yes_token = scores[:, self.yes_token].clone()
        scores_no_token = scores[:, self.no_token].clone()
        scores_min = scores.min()
        scores[:, :] = scores_min - 1
        scores[:, self.yes_token] = scores_yes_token
        scores[:, self.no_token] = scores_no_token
        return scores

    def test_image(self, image: Image.Image | str | torch.Tensor) -> bool:
        """
        Return True if the image should be blocked: its NSFW score exceeds the threshold,
        or Pixtral answers "yes" to the copyright/public-figure prompt.
        """
        if isinstance(image, torch.Tensor):
            image = rearrange(image[0].clamp(-1.0, 1.0), "c h w -> h w c")
            image = Image.fromarray((127.5 * (image + 1.0)).cpu().byte().numpy())
        elif isinstance(image, str):
            image = Image.open(image)

        classification = next(c for c in self.nsfw_classifier(image) if c["label"] == "nsfw")
        if classification["score"] > self.nsfw_threshold:
            return True

        # rescale to about 512^2 pixels (aspect ratio preserved); that is enough for the check
        w, h = image.size
        f = (512**2 / (w * h)) ** 0.5
        image = image.resize((int(f * w), int(f * h)))

        chat = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "content": PROMPT_IMAGE_INTEGRITY,
                    },
                    {
                        "type": "image",
                        "image": image,
                    },
                    {
                        "type": "text",
                        "content": PROMPT_IMAGE_INTEGRITY_FOLLOW_UP,
                    },
                ],
            }
        ]

        inputs = self.processor.apply_chat_template(
            chat,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device)

        generate_ids = self.model.generate(
            **inputs,
            max_new_tokens=1,
            logits_processor=[self.yes_no_logit_processor],
            do_sample=False,
        )
        return generate_ids[0, -1].item() == self.yes_token

    def test_txt(self, txt: str) -> bool:
        """Return True if Pixtral answers "yes" to the copyright/public-figure prompt for `txt`."""
        chat = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "content": PROMPT_TEXT_INTEGRITY.format(prompt=txt),
                    },
                ],
            }
        ]

        inputs = self.processor.apply_chat_template(
            chat,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        ).to(self.model.device)

        generate_ids = self.model.generate(
            **inputs,
            max_new_tokens=1,
            logits_processor=[self.yes_no_logit_processor],
            do_sample=False,
        )
        return generate_ids[0, -1].item() == self.yes_token
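Two small pieces of `PixtralContentFilter` are easy to misread: the logits processor pushes every score except "yes"/"no" below the row minimum, so greedy decoding (`do_sample=False`, one new token) must pick one of the two; and `test_image` rescales by a single factor toward 512^2 pixels, which also upscales smaller images. The sketch below restates both in plain Python for a single row of scores; the helper names and token ids are hypothetical and this is not the repository's torch code.

```python
def mask_to_yes_no(scores: list[float], yes_token: int, no_token: int) -> list[float]:
    """Keep the yes/no scores and set every other score below the row minimum."""
    floor = min(scores) - 1  # strictly below every original score
    masked = [floor] * len(scores)
    masked[yes_token] = scores[yes_token]
    masked[no_token] = scores[no_token]
    return masked


def greedy_says_yes(scores: list[float], yes_token: int, no_token: int) -> bool:
    """Greedy decoding picks the arg-max of the masked row."""
    masked = mask_to_yes_no(scores, yes_token, no_token)
    best = max(range(len(masked)), key=masked.__getitem__)
    return best == yes_token


def rescale_to_area(w: int, h: int, area: int = 512**2) -> tuple[int, int]:
    """Scale (w, h) by one factor so that w * h is close to `area`."""
    f = (area / (w * h)) ** 0.5
    return int(f * w), int(f * h)
```

For example, `greedy_says_yes([0.5, 3.0, -1.0, 2.0], yes_token=3, no_token=2)` is `True` even though token 1 has the highest raw score, and `rescale_to_area(256, 256)` returns `(512, 512)`, i.e. small images are enlarged before being sent to Pixtral.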
================================================
FILE: src/flux/math.py
================================================
import torch
from einops import rearrange
from torch import Tensor


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    # rotate queries and keys by their positions, then run standard SDPA
    q, k = apply_rope(q, k, pe)

    x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    # merge heads: (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
    x = rearrange(x, "B H L D -> B L (H D)")

    return x


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    assert dim % 2 == 0
    # one frequency per channel pair i: theta ** (-2i / dim)
    scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    # 2x2 rotation matrix [[cos, -sin], [sin, cos]] per position and frequency
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.float()


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    # view the last dim as consecutive (x0, x1) pairs and multiply each pair by its rotation matrix
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
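`rope` builds one 2x2 rotation matrix per channel pair, with angle `pos * theta ** (-2i / dim)`, and `apply_rope` multiplies each pair `(x0, x1)` of the query and key by it. The useful consequence is that the query/key dot product depends only on the position offset. The sketch below restates this for a single position axis in plain Python (no torch); FLUX's `EmbedND` combines several axes, which is not modeled here, and the function names are hypothetical.

```python
import math


def rope_angles(pos: float, dim: int, theta: float = 10_000.0) -> list[float]:
    """Angle for each channel pair i, as in flux.math.rope: pos * theta ** (-2i / dim)."""
    assert dim % 2 == 0
    return [pos / theta ** (k / dim) for k in range(0, dim, 2)]


def rotate_pairs(x: list[float], pos: float, theta: float = 10_000.0) -> list[float]:
    """Apply [[cos, -sin], [sin, cos]] to each pair (x[2i], x[2i+1]), like apply_rope."""
    out: list[float] = []
    for i, a in enumerate(rope_angles(pos, len(x), theta)):
        x0, x1 = x[2 * i], x[2 * i + 1]
        out.append(math.cos(a) * x0 - math.sin(a) * x1)
        out.append(math.sin(a) * x0 + math.cos(a) * x1)
    return out


def dot(a: list[float], b: list[float]) -> float:
    return sum(p * q for p, q in zip(a, b))


q = [1.0, 2.0, 3.0, 4.0]
k = [0.5, -1.0, 2.0, 0.25]
# Same offset (5 - 3 == 12 - 10), so the attention logits agree up to rounding.
same_logit = math.isclose(dot(rotate_pairs(q, 5), rotate_pairs(k, 3)), dot(rotate_pairs(q, 12), rotate_pairs(k, 10)))
```

Because each pair is rotated rather than scaled, vector norms are also preserved, so RoPE changes only the relative angle between queries and keys.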
================================================
FILE: src/flux/model.py
================================================
from dataclasses import dataclass

import torch
from torch import Tensor, nn

from flux.modules.layers import (
    DoubleStreamBlock,
    EmbedND,
    LastLayer,
    MLPEmbedder,
    SingleStreamBlock,
    timestep_embedding,
)
from flux.modules.lora import LinearLora, replace_linear_with_lora


@dataclass
class FluxParams:
    in_channels: int
    out_channels: int
    vec_in_dim: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    axes_dim: list[int]
    theta: int
    qkv_bias: bool
    guidance_embed: bool


class Flux(nn.Module):
    """
    Transformer model for flow matching on sequences.
    """

    def __init__(self, params: FluxParams):
        super().__init__()

        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = params.out_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
        )
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                )
                for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        img = img[:, txt.shape[1] :, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img


class FluxLoraWrapper(Flux):
    def __init__(
        self,
        lora_rank: int = 128,
        lora_scale: float = 1.0,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)

        self.lora_rank = lora_rank

        replace_linear_with_lora(
            self,
            max_rank=lora_rank,
            scale=lora_scale,
        )

    def set_lora_scale(self, scale: float) -> None:
        for module in self.modules():
            if isinstance(module, LinearLora):
                module.set_scale(scale=scale)
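`Flux.__init__` enforces two shape constraints before building any layer: the hidden size must split evenly across heads, and the per-head dimension must equal the sum of `axes_dim`, because `EmbedND` assigns each positional axis its own slice of every head's channels. A plain-Python sketch of just those checks follows; `HeadConfig` and `positional_dim` are hypothetical names, and the example values (hidden size 3072, 24 heads, `axes_dim=[16, 56, 56]`) are assumed to match the FLUX.1 configs in `flux/util.py`, which is not shown here.

```python
from dataclasses import dataclass


@dataclass
class HeadConfig:
    """Hypothetical subset of FluxParams needed for the positional-embedding checks."""

    hidden_size: int
    num_heads: int
    axes_dim: list[int]


def positional_dim(cfg: HeadConfig) -> int:
    """Mirror the validation in Flux.__init__ and return the per-head RoPE dimension."""
    if cfg.hidden_size % cfg.num_heads != 0:
        raise ValueError(f"Hidden size {cfg.hidden_size} must be divisible by num_heads {cfg.num_heads}")
    pe_dim = cfg.hidden_size // cfg.num_heads
    if sum(cfg.axes_dim) != pe_dim:
        raise ValueError(f"Got {cfg.axes_dim} but expected positional dim {pe_dim}")
    return pe_dim


pe_dim = positional_dim(HeadConfig(hidden_size=3072, num_heads=24, axes_dim=[16, 56, 56]))
```

Here 3072 / 24 = 128 and 16 + 56 + 56 = 128, so both checks pass; changing any single axis size (for example `[16, 56, 48]`) raises a `ValueError` before any weights are allocated.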
================================================
FILE: src/flux/modules/autoencoder.py
================================================
from dataclasses import dataclass

import torch
from einops import rearrange
from torch import Tensor, nn


@dataclass
class AutoEncoderParams:
    resolution: int
    in_channels: int
    ch: int
    out_ch: int
    ch_mult: list[int]
    num_res_blocks: int
    z_channels: int
    scale_factor: float
    shift_factor: float


def swish(x: Tensor) -> Tensor:
    return x * torch.sigmoid(x)


class AttnBlock(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
        k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
        v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
        h_ = nn.functional.scaled_dot_product_attention(q, k, v)

        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

    def forward(self, x: Tensor) -> Tensor:
        return x + self.proj_out(self.attention(x))


class ResnetBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h = x
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)

        return x + h
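

class Downsample(nn.Module):
    def __init__(self, in_channels: int):
        super().__init__()
        # no asymmetric padding in torch conv, must do it ourselves
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: Tensor):
        # pad one pixel on the right and bottom only, then downsample with the stride-2 conv
        pad = (0, 1, 0, 1)
        x = nn.functional.pad(x, pad, mode="constant", value=0)
        x = self.conv(x)
        return x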
SYMBOL INDEX (266 symbols across 29 files)
FILE: demo_gr.py
function get_models (line 27) | def get_models(name: str, device: torch.device, offload: bool, is_schnel...
class FluxGenerator (line 36) | class FluxGenerator:
method __init__ (line 37) | def __init__(self, model_name: str, device: str, offload: bool, track_...
method generate_image (line 51) | def generate_image(
function create_demo (line 175) | def create_demo(
FILE: demo_st.py
function get_models (line 32) | def get_models(name: str, device: torch.device, offload: bool, is_schnel...
function get_image (line 41) | def get_image() -> torch.Tensor | None:
function main (line 58) | def main(
function app (line 292) | def app():
FILE: demo_st_fill.py
function add_border_and_mask (line 31) | def add_border_and_mask(image, zoom_all=1.0, zoom_left=0, zoom_right=0, ...
function get_models (line 83) | def get_models(name: str, device: torch.device, offload: bool):
function resize (line 92) | def resize(img: Image.Image, min_mp: float = 0.5, max_mp: float = 2.0) -...
function clear_canvas_state (line 116) | def clear_canvas_state():
function set_new_image (line 124) | def set_new_image(img: Image.Image):
function downscale_image (line 131) | def downscale_image(img: Image.Image, scale_factor: float) -> Image.Image:
function main (line 148) | def main(
function app (line 497) | def app():
FILE: src/flux/cli.py
class SamplingOptions (line 26) | class SamplingOptions:
function parse_prompt (line 35) | def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
function main (line 103) | def main(
FILE: src/flux/cli_control.py
class SamplingOptions (line 17) | class SamplingOptions:
function parse_prompt (line 28) | def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
function parse_img_cond_path (line 95) | def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOpti...
function parse_lora_scale (line 134) | def parse_lora_scale(options: SamplingOptions | None) -> tuple[SamplingO...
function main (line 162) | def main(
FILE: src/flux/cli_fill.py
class SamplingOptions (line 17) | class SamplingOptions:
function parse_prompt (line 28) | def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
function parse_img_cond_path (line 73) | def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOpti...
function parse_img_mask_path (line 119) | def parse_img_mask_path(options: SamplingOptions | None) -> SamplingOpti...
function main (line 175) | def main(
FILE: src/flux/cli_kontext.py
class SamplingOptions (line 24) | class SamplingOptions:
function parse_prompt (line 34) | def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
function parse_img_cond_path (line 108) | def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOpti...
function main (line 150) | def main(
FILE: src/flux/cli_redux.py
class SamplingOptions (line 24) | class SamplingOptions:
function parse_prompt (line 34) | def parse_prompt(options: SamplingOptions) -> SamplingOptions | None:
function parse_img_cond_path (line 99) | def parse_img_cond_path(options: SamplingOptions | None) -> SamplingOpti...
function main (line 139) | def main(
FILE: src/flux/content_filters.py
class PixtralContentFilter (line 59) | class PixtralContentFilter(torch.nn.Module):
method __init__ (line 60) | def __init__(
method yes_no_logit_processor (line 78) | def yes_no_logit_processor(
method test_image (line 92) | def test_image(self, image: Image.Image | str | torch.Tensor) -> bool:
method test_txt (line 144) | def test_txt(self, txt: str) -> bool:
FILE: src/flux/math.py
function attention (line 6) | def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
function rope (line 15) | def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
function apply_rope (line 25) | def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tenso...
FILE: src/flux/model.py
class FluxParams (line 18) | class FluxParams:
class Flux (line 34) | class Flux(nn.Module):
method __init__ (line 39) | def __init__(self, params: FluxParams):
method forward (line 84) | def forward(
class FluxLoraWrapper (line 122) | class FluxLoraWrapper(Flux):
method __init__ (line 123) | def __init__(
method set_lora_scale (line 140) | def set_lora_scale(self, scale: float) -> None:
FILE: src/flux/modules/autoencoder.py
class AutoEncoderParams (line 9) | class AutoEncoderParams:
function swish (line 21) | def swish(x: Tensor) -> Tensor:
class AttnBlock (line 25) | class AttnBlock(nn.Module):
method __init__ (line 26) | def __init__(self, in_channels: int):
method attention (line 37) | def attention(self, h_: Tensor) -> Tensor:
method forward (line 51) | def forward(self, x: Tensor) -> Tensor:
class ResnetBlock (line 55) | class ResnetBlock(nn.Module):
method __init__ (line 56) | def __init__(self, in_channels: int, out_channels: int):
method forward (line 69) | def forward(self, x):
class Downsample (line 85) | class Downsample(nn.Module):
method __init__ (line 86) | def __init__(self, in_channels: int):
method forward (line 91) | def forward(self, x: Tensor):
class Upsample (line 98) | class Upsample(nn.Module):
method __init__ (line 99) | def __init__(self, in_channels: int):
method forward (line 103) | def forward(self, x: Tensor):
class Encoder (line 109) | class Encoder(nn.Module):
method __init__ (line 110) | def __init__(
method forward (line 159) | def forward(self, x: Tensor) -> Tensor:
class Decoder (line 183) | class Decoder(nn.Module):
method __init__ (line 184) | def __init__(
method forward (line 237) | def forward(self, z: Tensor) -> Tensor:
class DiagonalGaussian (line 267) | class DiagonalGaussian(nn.Module):
method __init__ (line 268) | def __init__(self, sample: bool = True, chunk_dim: int = 1):
method forward (line 273) | def forward(self, z: Tensor) -> Tensor:
class AutoEncoder (line 282) | class AutoEncoder(nn.Module):
method __init__ (line 283) | def __init__(self, params: AutoEncoderParams, sample_z: bool = False):
method encode (line 308) | def encode(self, x: Tensor) -> Tensor:
method decode (line 313) | def decode(self, z: Tensor) -> Tensor:
method forward (line 317) | def forward(self, x: Tensor) -> Tensor:
FILE: src/flux/modules/conditioner.py
class HFEmbedder (line 5) | class HFEmbedder(nn.Module):
method __init__ (line 6) | def __init__(self, version: str, max_length: int, **hf_kwargs):
method forward (line 21) | def forward(self, text: list[str]) -> Tensor:
FILE: src/flux/modules/image_embedders.py
class DepthImageEncoder (line 13) | class DepthImageEncoder:
method __init__ (line 16) | def __init__(self, device):
method __call__ (line 21) | def __call__(self, img: torch.Tensor) -> torch.Tensor:
class CannyImageEncoder (line 36) | class CannyImageEncoder:
method __init__ (line 37) | def __init__(
method __call__ (line 47) | def __call__(self, img: torch.Tensor) -> torch.Tensor:
class ReduxImageEncoder (line 64) | class ReduxImageEncoder(nn.Module):
method __init__ (line 67) | def __init__(
method __call__ (line 92) | def __call__(self, x: Image.Image) -> torch.Tensor:
FILE: src/flux/modules/layers.py
class EmbedND (line 11) | class EmbedND(nn.Module):
method __init__ (line 12) | def __init__(self, dim: int, theta: int, axes_dim: list[int]):
method forward (line 18) | def forward(self, ids: Tensor) -> Tensor:
function timestep_embedding (line 28) | def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: fl...
class MLPEmbedder (line 52) | class MLPEmbedder(nn.Module):
method __init__ (line 53) | def __init__(self, in_dim: int, hidden_dim: int):
method forward (line 59) | def forward(self, x: Tensor) -> Tensor:
class RMSNorm (line 63) | class RMSNorm(torch.nn.Module):
method __init__ (line 64) | def __init__(self, dim: int):
method forward (line 68) | def forward(self, x: Tensor):
class QKNorm (line 75) | class QKNorm(torch.nn.Module):
method __init__ (line 76) | def __init__(self, dim: int):
method forward (line 81) | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Te...
class SelfAttention (line 87) | class SelfAttention(nn.Module):
method __init__ (line 88) | def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
method forward (line 97) | def forward(self, x: Tensor, pe: Tensor) -> Tensor:
class ModulationOut (line 107) | class ModulationOut:
class Modulation (line 113) | class Modulation(nn.Module):
method __init__ (line 114) | def __init__(self, dim: int, double: bool):
method forward (line 120) | def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut |...
class DoubleStreamBlock (line 129) | class DoubleStreamBlock(nn.Module):
method __init__ (line 130) | def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float,...
method forward (line 158) | def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -...
class SingleStreamBlock (line 194) | class SingleStreamBlock(nn.Module):
method __init__ (line 200) | def __init__(
method forward (line 227) | def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
class LastLayer (line 242) | class LastLayer(nn.Module):
method __init__ (line 243) | def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
method forward (line 249) | def forward(self, x: Tensor, vec: Tensor) -> Tensor:
FILE: src/flux/modules/lora.py
function replace_linear_with_lora (line 5) | def replace_linear_with_lora(
class LinearLora (line 34) | class LinearLora(nn.Linear):
method __init__ (line 35) | def __init__(
method set_scale (line 84) | def set_scale(self, scale: float) -> None:
method forward (line 88) | def forward(self, input: torch.Tensor) -> torch.Tensor:
FILE: src/flux/sampling.py
function get_noise (line 17) | def get_noise(
function prepare (line 36) | def prepare(t5: HFEmbedder, clip: HFEmbedder, img: Tensor, prompt: str |...
function prepare_control (line 70) | def prepare_control(
function prepare_fill (line 107) | def prepare_fill(
function prepare_redux (line 160) | def prepare_redux(
function prepare_kontext (line 210) | def prepare_kontext(
function time_shift (line 277) | def time_shift(mu: float, sigma: float, t: Tensor):
function get_lin_function (line 281) | def get_lin_function(
function get_schedule (line 289) | def get_schedule(
function denoise (line 308) | def denoise(
function unpack (line 356) | def unpack(x: Tensor, height: int, width: int) -> Tensor:
FILE: src/flux/trt/engine/base_engine.py
class SharedMemory (line 32) | class SharedMemory(object):
method __new__ (line 33) | def __new__(cls, *args, **kwargs):
method __init__ (line 39) | def __init__(self, size: int, device=torch.device("cuda")):
method resize (line 48) | def resize(self, name: str, size: int):
method reset (line 54) | def reset(self, name: str):
method deallocate (line 61) | def deallocate(self):
method shared_device_memory (line 72) | def shared_device_memory(self):
method __str__ (line 75) | def __str__(self):
class BaseEngine (line 99) | class BaseEngine(ABC):
method trt_datatype_to_torch (line 101) | def trt_datatype_to_torch(datatype):
method cpu (line 118) | def cpu(self) -> "BaseEngine":
method cuda (line 122) | def cuda(self) -> "BaseEngine":
method to (line 126) | def to(self, device: str | torch.device) -> "BaseEngine":
class Engine (line 130) | class Engine(BaseEngine):
method __init__ (line 131) | def __init__(
method __call__ (line 157) | def __call__(self, *args, **Kwargs) -> torch.Tensor | dict[str, torch....
method cpu (line 160) | def cpu(self) -> "Engine":
method cuda (line 171) | def cuda(self) -> "Engine":
method to (line 182) | def to(self, device: str | torch.device) -> "Engine":
method deactivate (line 194) | def deactivate(self):
method allocate_buffers (line 198) | def allocate_buffers(
method get_dtype (line 220) | def get_dtype(self, tensor_name: str):
method override_shapes (line 223) | def override_shapes(self, feed_dict: Dict[str, torch.Tensor]):
method deallocate_buffers (line 249) | def deallocate_buffers(self):
method device_memory_size (line 258) | def device_memory_size(self):
method calculate_input_hash (line 267) | def calculate_input_hash(feed_dict: Dict[str, torch.Tensor]):
method _capture_cuda_graph (line 270) | def _capture_cuda_graph(self):
method infer (line 280) | def infer(
method __str__ (line 297) | def __str__(self):
FILE: src/flux/trt/engine/clip_engine.py
class CLIPEngine (line 24) | class CLIPEngine(Engine):
method __init__ (line 25) | def __init__(self, trt_config: ClipConfig, stream: torch.cuda.Stream, ...
method __call__ (line 33) | def __call__(
FILE: src/flux/trt/engine/t5_engine.py
class T5Engine (line 24) | class T5Engine(Engine):
method __init__ (line 25) | def __init__(self, trt_config: T5Config, stream: torch.cuda.Stream, **...
method __call__ (line 33) | def __call__(
FILE: src/flux/trt/engine/transformer_engine.py
class TransformerEngine (line 23) | class TransformerEngine(Engine):
method __init__ (line 46) | def __init__(self, trt_config: TransformerConfig, stream: torch.cuda.S...
method dd_to_flux (line 50) | def dd_to_flux(self):
method flux_to_dd (line 54) | def flux_to_dd(self):
method __call__ (line 58) | def __call__(
FILE: src/flux/trt/engine/vae_engine.py
class VAEDecoder (line 23) | class VAEDecoder(Engine):
method __init__ (line 24) | def __init__(self, trt_config: VAEDecoderConfig, stream: torch.cuda.St...
method __call__ (line 28) | def __call__(
class VAEEncoder (line 39) | class VAEEncoder(Engine):
method __init__ (line 40) | def __init__(self, trt_config: VAEEncoderConfig, stream: torch.cuda.St...
method __call__ (line 44) | def __call__(
class VAEEngine (line 54) | class VAEEngine(BaseEngine):
method __init__ (line 55) | def __init__(
method decode (line 64) | def decode(self, z: torch.Tensor) -> torch.Tensor:
method encode (line 67) | def encode(self, x: torch.Tensor) -> torch.Tensor:
method cpu (line 71) | def cpu(self):
method cuda (line 77) | def cuda(self):
method to (line 83) | def to(self, device):
method device_memory_size (line 90) | def device_memory_size(self):
FILE: src/flux/trt/trt_config/base_trt_config.py
class ModuleName (line 30) | class ModuleName(Enum):
class TRTBaseConfig (line 42) | class TRTBaseConfig:
method build_trt_engine (line 69) | def build_trt_engine(
method from_args (line 174) | def from_args(cls, model_name: str, *args, **kwargs) -> Any:
method get_input_profile (line 178) | def get_input_profile(
method check_dims (line 209) | def check_dims(self, *args, **kwargs) -> None | tuple[int, int] | int:
method _check_batch (line 213) | def _check_batch(self, batch_size):
method __post_init__ (line 218) | def __post_init__(self):
method _get_onnx_path (line 223) | def _get_onnx_path(self) -> str:
method _get_engine_path (line 232) | def _get_engine_path(self) -> str:
method _get_repo_id (line 240) | def _get_repo_id(model_name: str) -> str:
function register_config (line 255) | def register_config(module_name: ModuleName, precision: str):
function get_config (line 266) | def get_config(module_name: ModuleName, precision: str) -> TRTBaseConfig:
FILE: src/flux/trt/trt_config/clip_trt_config.py
class ClipConfig (line 25) | class ClipConfig(TRTBaseConfig):
method from_args (line 35) | def from_args(
method check_dims (line 48) | def check_dims(self, batch_size: int) -> None:
method get_input_profile (line 51) | def get_input_profile(
FILE: src/flux/trt/trt_config/t5_trt_config.py
class T5Config (line 29) | class T5Config(TRTBaseConfig):
method from_args (line 40) | def from_args(
method check_dims (line 53) | def check_dims(self, batch_size: int) -> None:
method get_input_profile (line 56) | def get_input_profile(
method _get_onnx_path (line 74) | def _get_onnx_path(self) -> str:
FILE: src/flux/trt/trt_config/transformer_trt_config.py
class TransformerConfig (line 31) | class TransformerConfig(TRTBaseConfig):
method from_args (line 57) | def from_args(
method _get_onnx_path (line 87) | def _get_onnx_path(self) -> str:
method _get_latent (line 99) | def _get_latent(image_dim: int, compression_factor: int) -> int:
method _get_context_dim (line 103) | def _get_context_dim(
method __post_init__ (line 118) | def __post_init__(self):
method get_minmax_dims (line 160) | def get_minmax_dims(
method check_dims (line 194) | def check_dims(
method get_input_profile (line 219) | def get_input_profile(
FILE: src/flux/trt/trt_config/vae_trt_config.py
class VAEBaseConfig (line 25) | class VAEBaseConfig(TRTBaseConfig):
method _get_latent_dim (line 38) | def _get_latent_dim(self, image_dim: int) -> int:
method __post_init__ (line 41) | def __post_init__(self):
method check_dims (line 46) | def check_dims(
class VAEDecoderConfig (line 71) | class VAEDecoderConfig(VAEBaseConfig):
method from_args (line 79) | def from_args(
method get_minmax_dims (line 102) | def get_minmax_dims(
method get_input_profile (line 128) | def get_input_profile(
class VAEEncoderConfig (line 175) | class VAEEncoderConfig(VAEBaseConfig):
method from_args (line 183) | def from_args(cls, model_name: str, **kwargs):
method get_minmax_dims (line 206) | def get_minmax_dims(
method get_input_profile (line 229) | def get_input_profile(
FILE: src/flux/trt/trt_manager.py
class TRTManager (line 46) | class TRTManager:
method module_to_engine_class (line 48) | def module_to_engine_class(self) -> dict[ModuleName, type[Engine]]:
method __init__ (line 57) | def __init__(
method _parse_models_precisions (line 76) | def _parse_models_precisions(
method _parse_custom_onnx_path (line 99) | def _parse_custom_onnx_path(custom_onnx_paths: str) -> dict[ModuleName...
method _create_directories (line 128) | def _create_directories(engine_dir: str):
method _get_trt_configs (line 132) | def _get_trt_configs(
method _build_engine (line 179) | def _build_engine(
method load_engines (line 211) | def load_engines(
method _clean_memory (line 279) | def _clean_memory():
method init_runtime (line 283) | def init_runtime(self):
method stop_runtime (line 290) | def stop_runtime(self):
FILE: src/flux/util.py
function ensure_hf_auth (line 27) | def ensure_hf_auth():
function prompt_for_hf_auth (line 45) | def prompt_for_hf_auth():
function get_checkpoint_path (line 64) | def get_checkpoint_path(repo_id: str, filename: str, env_var: str) -> Path:
function download_onnx_models_for_trt (line 108) | def download_onnx_models_for_trt(model_name: str, trt_transformer_precis...
function check_onnx_access_for_trt (line 201) | def check_onnx_access_for_trt(model_name: str, trt_transformer_precision...
function track_usage_via_api (line 206) | def track_usage_via_api(name: str, n=1) -> None:
function save_image (line 243) | def save_image(
class ModelSpec (line 288) | class ModelSpec:
function aspect_ratio_to_height_width (line 637) | def aspect_ratio_to_height_width(aspect_ratio: str, area: int = 1024**2)...
function print_load_warning (line 646) | def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
function load_flow_model (line 657) | def load_flow_model(name: str, device: str | torch.device = "cuda", verb...
function load_t5 (line 689) | def load_t5(device: str | torch.device = "cuda", max_length: int = 512) ...
function load_clip (line 694) | def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
function load_ae (line 698) | def load_ae(name: str, device: str | torch.device = "cuda") -> AutoEncoder:
function optionally_expand_state_dict (line 714) | def optionally_expand_state_dict(model: torch.nn.Module, state_dict: dic...
class WatermarkEmbedder (line 733) | class WatermarkEmbedder:
method __init__ (line 734) | def __init__(self, watermark):
method __call__ (line 740) | def __call__(self, image: torch.Tensor) -> torch.Tensor:
Condensed preview: 51 files, each showing path, character count, and a content snippet.
[
{
"path": ".github/workflows/ci.yaml",
"chars": 545,
"preview": "name: CI\non: push\njobs:\n lint:\n runs-on: ubuntu-latest\n steps:\n - uses: actions/checkout@v2\n - uses: ac"
},
{
"path": ".gitignore",
"chars": 3706,
"preview": "# Created by https://www.toptal.com/developers/gitignore/api/linux,windows,macos,visualstudiocode,python\n# Edit at https"
},
{
"path": "LICENSE",
"chars": 11357,
"preview": " Apache License\n Version 2.0, January 2004\n "
},
{
"path": "README.md",
"chars": 6721,
"preview": "# FLUX\nby Black Forest Labs: https://bfl.ai.\n\nDocumentation for our API can be found here: [docs.bfl.ai](https://docs.bf"
},
{
"path": "demo_gr.py",
"chars": 9550,
"preview": "import os\nimport time\nimport uuid\n\nimport gradio as gr\nimport numpy as np\nimport torch\nfrom einops import rearrange\nfrom"
},
{
"path": "demo_st.py",
"chars": 9874,
"preview": "import os\nimport re\nimport time\nfrom glob import iglob\nfrom io import BytesIO\n\nimport streamlit as st\nimport torch\nfrom "
},
{
"path": "demo_st_fill.py",
"chars": 17944,
"preview": "import os\nimport re\nimport tempfile\nimport time\nfrom glob import iglob\nfrom io import BytesIO\n\nimport numpy as np\nimport"
},
{
"path": "docs/fill.md",
"chars": 1882,
"preview": "## Open-weight models\n\nFLUX.1 Fill introduces advanced inpainting and outpainting capabilities. It allows for seamless e"
},
{
"path": "docs/image-editing.md",
"chars": 2027,
"preview": "## Open-weight models\n\nWe currently offer two open-weight text-to-image models.\n\n| Name | HuggingFa"
},
{
"path": "docs/image-variation.md",
"chars": 1532,
"preview": "## Models\n\nFLUX.1 Redux is an adapter for the FLUX.1 text-to-image base models, FLUX.1 [dev] and FLUX.1 [schnell], which"
},
{
"path": "docs/structural-conditioning.md",
"chars": 3506,
"preview": "## Models\n\nStructural conditioning uses canny edge or depth detection to maintain precise control during image transform"
},
{
"path": "docs/text-to-image.md",
"chars": 4150,
"preview": "## Open-weight models\n\nWe currently offer two open-weight text-to-image models.\n\n| Name | HuggingFa"
},
{
"path": "model_cards/FLUX.1-Krea-dev.md",
"chars": 6477,
"preview": "![FLUX.1 Krea [dev] Grid](../assets/flux-1-krea-dev-grid.png)\n\n`FLUX.1 Krea [dev]` is a 12 billion parameter rectified f"
},
{
"path": "model_cards/FLUX.1-dev.md",
"chars": 2933,
"preview": "![FLUX.1 [dev] Grid](../assets/dev_grid.jpg)\n\n`FLUX.1 [dev]` is a 12 billion parameter rectified flow transformer capabl"
},
{
"path": "model_cards/FLUX.1-kontext-dev.md",
"chars": 7610,
"preview": "![FLUX.1 [dev] Grid](../assets/docs/kontext.png)\n\n`FLUX.1 Kontext [dev]` is a 12 billion parameter rectified flow transf"
},
{
"path": "model_cards/FLUX.1-schnell.md",
"chars": 2665,
"preview": "![FLUX.1 [schnell] Grid](../assets/schnell_grid.jpg)\n\n`FLUX.1 [schnell]` is a 12 billion parameter rectified flow transf"
},
{
"path": "model_licenses/LICENSE-FLUX1-dev",
"chars": 18491,
"preview": "FLUX.1 [dev] Non-Commercial License v1.1.1\n\nBlack Forest Labs Inc. (“we” or “our” or “Company”) is pleased to make avail"
},
{
"path": "model_licenses/LICENSE-FLUX1-schnell",
"chars": 9155,
"preview": "\n\nApache License\nVersion 2.0, January 2004\nhttp://www.apache.org/licenses/\n\nTERMS AND CONDITIONS FOR USE, REPRODUCTION, "
},
{
"path": "pyproject.toml",
"chars": 2239,
"preview": "[project]\nname = \"flux\"\nauthors = [\n { name = \"Black Forest Labs\", email = \"support@blackforestlabs.ai\" },\n]\ndescriptio"
},
{
"path": "setup.py",
"chars": 38,
"preview": "import setuptools\n\nsetuptools.setup()\n"
},
{
"path": "src/flux/__init__.py",
"chars": 345,
"preview": "try:\n from ._version import (\n version as __version__, # type: ignore\n version_tuple,\n )\nexcept Imp"
},
{
"path": "src/flux/__main__.py",
"chars": 462,
"preview": "from fire import Fire\n\nfrom .cli import main as cli_main\nfrom .cli_control import main as control_main\nfrom .cli_fill im"
},
{
"path": "src/flux/cli.py",
"chars": 10486,
"preview": "import os\nimport re\nimport time\nfrom dataclasses import dataclass\nfrom glob import iglob\n\nimport torch\nfrom fire import "
},
{
"path": "src/flux/cli_control.py",
"chars": 13733,
"preview": "import os\nimport re\nimport time\nfrom dataclasses import dataclass\nfrom glob import iglob\n\nimport torch\nfrom fire import "
},
{
"path": "src/flux/cli_fill.py",
"chars": 11209,
"preview": "import os\nimport re\nimport time\nfrom dataclasses import dataclass\nfrom glob import iglob\n\nimport torch\nfrom fire import "
},
{
"path": "src/flux/cli_kontext.py",
"chars": 13165,
"preview": "import os\nimport re\nimport time\nfrom dataclasses import dataclass\nfrom glob import iglob\n\nimport torch\nfrom fire import "
},
{
"path": "src/flux/cli_redux.py",
"chars": 9767,
"preview": "import os\nimport re\nimport time\nfrom dataclasses import dataclass\nfrom glob import iglob\n\nimport torch\nfrom fire import "
},
{
"path": "src/flux/content_filters.py",
"chars": 5959,
"preview": "import torch\nfrom einops import rearrange\nfrom PIL import Image\nfrom transformers import AutoProcessor, LlavaForConditio"
},
{
"path": "src/flux/math.py",
"chars": 1166,
"preview": "import torch\nfrom einops import rearrange\nfrom torch import Tensor\n\n\ndef attention(q: Tensor, k: Tensor, v: Tensor, pe: "
},
{
"path": "src/flux/model.py",
"chars": 4399,
"preview": "from dataclasses import dataclass\n\nimport torch\nfrom torch import Tensor, nn\n\nfrom flux.modules.layers import (\n Doub"
},
{
"path": "src/flux/modules/autoencoder.py",
"chars": 10652,
"preview": "from dataclasses import dataclass\n\nimport torch\nfrom einops import rearrange\nfrom torch import Tensor, nn\n\n\n@dataclass\nc"
},
{
"path": "src/flux/modules/conditioner.py",
"chars": 1502,
"preview": "from torch import Tensor, nn\nfrom transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer\n\n\nclass "
},
{
"path": "src/flux/modules/image_embedders.py",
"chars": 3465,
"preview": "import cv2\nimport numpy as np\nimport torch\nfrom einops import rearrange, repeat\nfrom PIL import Image\nfrom safetensors.t"
},
{
"path": "src/flux/modules/layers.py",
"chars": 9374,
"preview": "import math\nfrom dataclasses import dataclass\n\nimport torch\nfrom einops import rearrange\nfrom torch import Tensor, nn\n\nf"
},
{
"path": "src/flux/modules/lora.py",
"chars": 2545,
"preview": "import torch\nfrom torch import nn\n\n\ndef replace_linear_with_lora(\n module: nn.Module,\n max_rank: int,\n scale: f"
},
{
"path": "src/flux/sampling.py",
"chars": 11434,
"preview": "import math\nfrom typing import Callable\n\nimport numpy as np\nimport torch\nfrom einops import rearrange, repeat\nfrom PIL i"
},
{
"path": "src/flux/trt/__init__.py",
"chars": 127,
"preview": "from flux.trt.trt_config import ModuleName\nfrom flux.trt.trt_manager import TRTManager\n\n__all__ = [\"TRTManager\", \"Module"
},
{
"path": "src/flux/trt/engine/__init__.py",
"chars": 1175,
"preview": "#\n# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License"
},
{
"path": "src/flux/trt/engine/base_engine.py",
"chars": 11083,
"preview": "#\n# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License"
},
{
"path": "src/flux/trt/engine/clip_engine.py",
"chars": 1853,
"preview": "#\n# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License"
},
{
"path": "src/flux/trt/engine/t5_engine.py",
"chars": 1826,
"preview": "#\n# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License"
},
{
"path": "src/flux/trt/engine/transformer_engine.py",
"chars": 2491,
"preview": "#\n# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License"
},
{
"path": "src/flux/trt/engine/vae_engine.py",
"chars": 3163,
"preview": "#\n# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License"
},
{
"path": "src/flux/trt/trt_config/__init__.py",
"chars": 1261,
"preview": "#\n# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License"
},
{
"path": "src/flux/trt/trt_config/base_trt_config.py",
"chars": 11052,
"preview": "#\n# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License"
},
{
"path": "src/flux/trt/trt_config/clip_trt_config.py",
"chars": 2105,
"preview": "#\n# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License"
},
{
"path": "src/flux/trt/trt_config/t5_trt_config.py",
"chars": 2736,
"preview": "#\n# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License"
},
{
"path": "src/flux/trt/trt_config/transformer_trt_config.py",
"chars": 10336,
"preview": "#\n# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License"
},
{
"path": "src/flux/trt/trt_config/vae_trt_config.py",
"chars": 9481,
"preview": "#\n# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License"
},
{
"path": "src/flux/trt/trt_manager.py",
"chars": 10243,
"preview": "#\n# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.\n# SPDX-License"
},
{
"path": "src/flux/util.py",
"chars": 26452,
"preview": "import getpass\nimport math\nimport os\nfrom dataclasses import dataclass\nfrom pathlib import Path\n\nimport requests\nimport "
}
]
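The condensed preview above is a JSON array of records, each with a `path`, a `chars` count, and a `preview` snippet. The following is a minimal Python sketch showing how such records could be aggregated, here by totaling character counts per top-level directory. The `SAMPLE` string reuses four entries from the listing with shortened previews. The helper name `chars_by_top_level` is hypothetical and is not part of the flux repository or of GitExtract.

```python
import json
from collections import defaultdict

# Hypothetical sample in the same shape as the condensed preview:
# a JSON array of objects with "path", "chars", and "preview" keys.
SAMPLE = """[
  {"path": "setup.py", "chars": 38, "preview": "import setuptools"},
  {"path": "src/flux/math.py", "chars": 1166, "preview": "import torch"},
  {"path": "src/flux/trt/trt_manager.py", "chars": 10243, "preview": "#"},
  {"path": "docs/fill.md", "chars": 1882, "preview": "## Open-weight models"}
]"""


def chars_by_top_level(records: list[dict]) -> dict[str, int]:
    """Sum the "chars" field per top-level directory; files at the repo root go under "."."""
    totals: dict[str, int] = defaultdict(int)
    for rec in records:
        parts = rec["path"].split("/")
        # A path with no "/" is a root-level file such as setup.py.
        key = parts[0] if len(parts) > 1 else "."
        totals[key] += rec["chars"]
    return dict(totals)


if __name__ == "__main__":
    records = json.loads(SAMPLE)
    for directory, total in sorted(chars_by_top_level(records).items()):
        print(f"{directory}: {total} chars")
```

The grouping key is simply the first path component, so nested paths such as `src/flux/trt/trt_manager.py` roll up into `src`.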
About this extraction
This page contains the full source code of the black-forest-labs/flux GitHub repository, extracted and formatted as plain text. The extraction includes 51 files (319.8 KB), approximately 78.9k tokens, and a symbol index with 266 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract.